diff --git a/internal/worker/client.go b/internal/worker/client.go index d35cc6713..32bc93224 100644 --- a/internal/worker/client.go +++ b/internal/worker/client.go @@ -12,9 +12,12 @@ import ( "net/http" "net/url" "sync" + "time" "github.com/google/uuid" + "github.com/sirupsen/logrus" + "github.com/osbuild/osbuild-composer/internal/common" "github.com/osbuild/osbuild-composer/internal/worker/api" ) @@ -26,8 +29,10 @@ type Client struct { accessToken string clientId string clientSecret string + workerID uuid.UUID - tokenMu sync.RWMutex + tokenMu sync.RWMutex + workerIDMu sync.RWMutex } type ClientConfig struct { @@ -49,7 +54,7 @@ type Job interface { NDynamicArgs() int Update(result interface{}) error Canceled() (bool, error) - UploadArtifact(name string, reader io.Reader) error + UploadArtifact(name string, readSeeker io.ReadSeeker) error } var ErrClientRequestJobTimeout = errors.New("Dequeue timed out, retry") @@ -108,14 +113,20 @@ func NewClient(conf ClientConfig) (*Client, error) { } requester.Transport = transport - return &Client{ + client := &Client{ server: server, requester: requester, offlineToken: conf.OfflineToken, oAuthURL: conf.OAuthURL, clientId: conf.ClientId, clientSecret: conf.ClientSecret, - }, nil + } + err = client.registerWorker() + if err != nil { + return client, err + } + go client.workerHeartbeat() + return client, nil } func NewClientUnix(conf ClientConfig) *Client { @@ -138,11 +149,103 @@ func NewClientUnix(conf ClientConfig) *Client { }, }, } - - return &Client{ + client := &Client{ server: server, requester: requester, } + err = client.registerWorker() + if err != nil { + panic(err) + } + go client.workerHeartbeat() + return client +} + +func (c *Client) registerWorker() error { + c.workerIDMu.Lock() + defer c.workerIDMu.Unlock() + + url, err := c.server.Parse("workers") + if err != nil { + return err + } + + var buf bytes.Buffer + err = json.NewEncoder(&buf).Encode(api.PostWorkersRequest{ + Arch: common.CurrentArch(), + }) + if err != nil { + logrus.Errorf("Unable create worker request: %v", err) + return err + } + + resp, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes())) + if err != nil { + logrus.Errorf("Unable to register worker: %v", err) + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + return errorFromResponse(resp, "error registering worker") + } + + var wresp api.PostWorkersResponse + err = json.NewDecoder(resp.Body).Decode(&wresp) + if err != nil { + return err + } + + workerID, err := uuid.Parse(wresp.WorkerId) + if err != nil { + return err + } + c.workerID = workerID + return nil +} + +func (c *Client) workerHeartbeat() { + //nolint:staticcheck // avoid SA1015, this is an endless function + for range time.Tick(time.Minute * 1) { + workerID := func() uuid.UUID { + c.workerIDMu.RLock() + defer c.workerIDMu.RUnlock() + return c.workerID + } + + if workerID() == uuid.Nil { + err := c.registerWorker() + if err != nil { + logrus.Errorf("Error registering worker, %v", err) + continue + } + } + + url, err := c.server.Parse(fmt.Sprintf("workers/%s/status", workerID())) + if err != nil { + logrus.Errorf("Error parsing worker status: %v", err) + continue + } + + var buf bytes.Buffer + resp, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes())) + if err != nil { + logrus.Errorf("Error updating worker status: %v", err) + continue + } + if resp.StatusCode == http.StatusBadRequest { + logrus.Warning("Registering worker again") + err = c.registerWorker() + if err != nil { + logrus.Errorf("Error registering worker, %v", err) + } + continue + } + if resp.StatusCode != http.StatusOK { + logrus.Errorf("Error updating worker status: %d", resp.StatusCode) + continue + } + } } func (c *Client) refreshAccessToken() error { @@ -176,26 +279,41 @@ func (c *Client) refreshAccessToken() error { return nil } -func (c *Client) NewRequest(method, url string, headers map[string]string, body io.Reader) (*http.Response, error) { +func (c *Client) NewRequest(method, url string, headers map[string]string, body io.ReadSeeker) (*http.Response, error) { token := func() string { c.tokenMu.RLock() defer c.tokenMu.RUnlock() return c.accessToken } - req, err := http.NewRequest(method, url, body) + newRequest := func() (*http.Request, error) { + if body != nil { + _, err := body.Seek(0, 0) + if err != nil { + return nil, err + } + } + + req, err := http.NewRequest(method, url, body) + if err != nil { + return nil, err + } + + if c.oAuthURL != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token())) + } + + for k, v := range headers { + req.Header.Add(k, v) + } + + return req, nil + } + + req, err := newRequest() if err != nil { return nil, err } - - if c.oAuthURL != "" { - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token())) - } - - for k, v := range headers { - req.Header.Add(k, v) - } - resp, err := c.requester.Do(req) if err != nil { return nil, err @@ -207,7 +325,10 @@ func (c *Client) NewRequest(method, url string, headers map[string]string, body return nil, err } - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token())) + req, err = newRequest() + if err != nil { + return nil, err + } resp, err = c.requester.Do(req) } return resp, err @@ -220,16 +341,28 @@ func (c *Client) RequestJob(types []string, arch string) (Job, error) { panic(err) } - var buf bytes.Buffer - err = json.NewEncoder(&buf).Encode(api.RequestJobJSONRequestBody{ + reqBody := api.RequestJobJSONRequestBody{ Types: types, Arch: arch, - }) + } + + workerID := func() uuid.UUID { + c.workerIDMu.RLock() + defer c.workerIDMu.RUnlock() + return c.workerID + }() + + if workerID != uuid.Nil { + reqBody.WorkerId = common.ToPtr(workerID.String()) + } + + var buf bytes.Buffer + err = json.NewEncoder(&buf).Encode(reqBody) if err != nil { panic(err) } - response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, &buf) + response, err := c.NewRequest("POST", url.String(), map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes())) if err != nil { return nil, err } @@ -320,7 +453,7 @@ func (j *job) Update(result interface{}) error { panic(err) } - response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, &buf) + response, err := j.client.NewRequest("PATCH", j.location, map[string]string{"Content-Type": "application/json"}, bytes.NewReader(buf.Bytes())) if err != nil { return fmt.Errorf("error fetching job info: %v", err) } @@ -353,7 +486,7 @@ func (j *job) Canceled() (bool, error) { return jr.Canceled, nil } -func (j *job) UploadArtifact(name string, reader io.Reader) error { +func (j *job) UploadArtifact(name string, readSeeker io.ReadSeeker) error { if j.artifactLocation == "" { return fmt.Errorf("server does not accept artifacts for this job") } @@ -368,7 +501,7 @@ func (j *job) UploadArtifact(name string, reader io.Reader) error { panic(err) } - response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, reader) + response, err := j.client.NewRequest("PUT", loc.String(), map[string]string{"Content-Type": "application/octet-stream"}, readSeeker) if err != nil { return fmt.Errorf("error uploading artifact: %v", err) } diff --git a/internal/worker/client_test.go b/internal/worker/client_test.go index 93a9b6b8f..78268e69d 100644 --- a/internal/worker/client_test.go +++ b/internal/worker/client_test.go @@ -128,11 +128,12 @@ func TestProxy(t *testing.T) { require.False(t, c) require.NoError(t, err) - // we expect 5 calls to go through the proxy: + // we expect 6 calls to go through the proxy: + // - register worker // - request job (fails, no oauth token) // - oauth call // - request job (succeeds) // - upload artifact // - cancel - require.Equal(t, 5, proxy.calls) + require.Equal(t, 6, proxy.calls) }