From 1b4935c325149d67450a97e3f4049fe5487330d7 Mon Sep 17 00:00:00 2001 From: Sanne Raymaekers Date: Fri, 19 Apr 2024 11:52:21 +0200 Subject: [PATCH] jobqueue: add channel to workers Stores the channel alongside the worker. --- internal/client/unit_test.go | 2 +- internal/jobqueue/fsjobqueue/fsjobqueue.go | 9 ++++++--- internal/jobqueue/jobqueuetest/jobqueuetest.go | 9 +++++---- internal/weldr/api_test.go | 4 ++-- internal/weldr/compose_test.go | 4 ++-- internal/worker/server.go | 17 ++++++++++++++--- pkg/jobqueue/dbjobqueue/dbjobqueue.go | 18 ++++++++++-------- .../schemas/008_workers_add_channel.sql | 2 ++ pkg/jobqueue/jobqueue.go | 7 ++++--- 9 files changed, 46 insertions(+), 26 deletions(-) create mode 100644 pkg/jobqueue/dbjobqueue/schemas/008_workers_add_channel.sql diff --git a/internal/client/unit_test.go b/internal/client/unit_test.go index 7c9facc95..3628daa68 100644 --- a/internal/client/unit_test.go +++ b/internal/client/unit_test.go @@ -65,7 +65,7 @@ func executeTests(m *testing.M) int { fixture := rpmmd_mock.BaseFixture(path.Join(tmpdir, "/jobs"), test_distro.TestDistro1Name, test_distro.TestArchName) defer fixture.StoreFixture.Cleanup() - _, err = fixture.Workers.RegisterWorker(fixture.StoreFixture.HostArchName) + _, err = fixture.Workers.RegisterWorker("", fixture.StoreFixture.HostArchName) if err != nil { panic(err) } diff --git a/internal/jobqueue/fsjobqueue/fsjobqueue.go b/internal/jobqueue/fsjobqueue/fsjobqueue.go index 348c2e723..938031903 100644 --- a/internal/jobqueue/fsjobqueue/fsjobqueue.go +++ b/internal/jobqueue/fsjobqueue/fsjobqueue.go @@ -62,6 +62,7 @@ type fsJobQueue struct { } type worker struct { + Channel string `json:"channel"` Arch string `json:"arch"` Heartbeat time.Time `json:"heartbeat"` Tokens map[uuid.UUID]struct{} @@ -467,12 +468,13 @@ func (q *fsJobQueue) RefreshHeartbeat(token uuid.UUID) { } } -func (q *fsJobQueue) InsertWorker(arch string) (uuid.UUID, error) { +func (q *fsJobQueue) InsertWorker(channel, arch string) (uuid.UUID, error) { q.mu.Lock() defer q.mu.Unlock() wID := uuid.New() q.workers[wID] = worker{ + Channel: channel, Arch: arch, Heartbeat: time.Now(), Tokens: make(map[uuid.UUID]struct{}), @@ -502,8 +504,9 @@ func (q *fsJobQueue) Workers(olderThan time.Duration) ([]jobqueue.Worker, error) for wID, w := range q.workers { if now.Sub(w.Heartbeat) > olderThan { workers = append(workers, jobqueue.Worker{ - ID: wID, - Arch: w.Arch, + ID: wID, + Channel: w.Channel, + Arch: w.Arch, }) } } diff --git a/internal/jobqueue/jobqueuetest/jobqueuetest.go b/internal/jobqueue/jobqueuetest/jobqueuetest.go index a558bd960..0da661d94 100644 --- a/internal/jobqueue/jobqueuetest/jobqueuetest.go +++ b/internal/jobqueue/jobqueuetest/jobqueuetest.go @@ -704,22 +704,23 @@ func test100dequeuers(t *testing.T, q jobqueue.JobQueue) { // Registers workers and runs jobs against them func testWorkers(t *testing.T, q jobqueue.JobQueue) { - one := pushTestJob(t, q, "octopus", nil, nil, "") + one := pushTestJob(t, q, "octopus", nil, nil, "chan") - w1, err := q.InsertWorker("x86_64") + w1, err := q.InsertWorker("chan", "x86_64") require.NoError(t, err) - w2, err := q.InsertWorker("aarch64") + w2, err := q.InsertWorker("chan", "aarch64") require.NoError(t, err) workers, err := q.Workers(0) require.NoError(t, err) require.Len(t, workers, 2) + require.Equal(t, "chan", workers[0].Channel) workers, err = q.Workers(time.Hour * 24) require.NoError(t, err) require.Len(t, workers, 0) - _, _, _, _, _, err = q.Dequeue(context.Background(), w1, []string{"octopus"}, []string{""}) + _, _, _, _, _, err = q.Dequeue(context.Background(), w1, []string{"octopus"}, []string{"chan"}) require.NoError(t, err) err = q.DeleteWorker(w1) diff --git a/internal/weldr/api_test.go b/internal/weldr/api_test.go index 823ba6e01..417896b5f 100644 --- a/internal/weldr/api_test.go +++ b/internal/weldr/api_test.go @@ -1338,7 +1338,7 @@ func TestCompose(t *testing.T) { api, sf := createTestWeldrAPI(t.TempDir(), test_distro.TestDistro1Name, test_distro.TestArchName, rpmmd_mock.NoComposesFixture, nil) t.Cleanup(sf.Cleanup) - _, err = api.workers.RegisterWorker(arch.Name()) + _, err = api.workers.RegisterWorker("", arch.Name()) require.NoError(t, err) test.TestRoute(t, api, c.External, c.Method, c.Path, c.Body, c.ExpectedStatus, c.ExpectedJSON, c.IgnoreFields...) @@ -2329,7 +2329,7 @@ func TestComposePOST_ImageTypeDenylist(t *testing.T) { t.Run(fmt.Sprintf("case %d", idx), func(t *testing.T) { api, sf := createTestWeldrAPI(t.TempDir(), distro2.Name(), arch.Name(), rpmmd_mock.NoComposesFixture, c.imageTypeDenylist) t.Cleanup(sf.Cleanup) - _, err = api.workers.RegisterWorker(arch.Name()) + _, err = api.workers.RegisterWorker("", arch.Name()) require.NoError(t, err) test.TestRoute(t, api, true, "POST", c.Path, c.Body, c.ExpectedStatus, c.ExpectedJSON, c.IgnoreFields...) diff --git a/internal/weldr/compose_test.go b/internal/weldr/compose_test.go index d80106f99..aba3a494c 100644 --- a/internal/weldr/compose_test.go +++ b/internal/weldr/compose_test.go @@ -43,7 +43,7 @@ func TestComposeStatusFromLegacyError(t *testing.T) { t.Fatalf("error serializing osbuild manifest: %v", err) } - _, err = api.workers.RegisterWorker(arch.Name()) + _, err = api.workers.RegisterWorker("", arch.Name()) require.NoError(t, err) jobId, err := api.workers.EnqueueOSBuild(arch.Name(), &worker.OSBuildJob{Manifest: mf}, "") require.NoError(t, err) @@ -99,7 +99,7 @@ func TestComposeStatusFromJobError(t *testing.T) { t.Fatalf("error serializing osbuild manifest: %v", err) } - _, err = api.workers.RegisterWorker(arch.Name()) + _, err = api.workers.RegisterWorker("", arch.Name()) require.NoError(t, err) jobId, err := api.workers.EnqueueOSBuild(arch.Name(), &worker.OSBuildJob{Manifest: mf}, "") require.NoError(t, err) diff --git a/internal/worker/server.go b/internal/worker/server.go index 530d18a4d..2782c16d3 100644 --- a/internal/worker/server.go +++ b/internal/worker/server.go @@ -823,8 +823,8 @@ func (s *Server) RequeueOrFinishJob(token uuid.UUID, maxRetries uint64, result j return nil } -func (s *Server) RegisterWorker(a string) (uuid.UUID, error) { - workerID, err := s.jobs.InsertWorker(a) +func (s *Server) RegisterWorker(c, a string) (uuid.UUID, error) { + workerID, err := s.jobs.InsertWorker(c, a) if err != nil { return uuid.Nil, err } @@ -1059,7 +1059,18 @@ func (h *apiHandlers) PostWorkers(ctx echo.Context) error { return err } - workerID, err := h.server.RegisterWorker(body.Arch) + var channel string + if h.server.config.JWTEnabled { + tenant, err := auth.GetFromClaims(ctx.Request().Context(), h.server.config.TenantProviderFields) + if err != nil { + return api.HTTPErrorWithInternal(api.ErrorTenantNotFound, err) + } + + // prefix the tenant to prevent collisions if support for specifying channels in a request is ever added + channel = "org-" + tenant + } + + workerID, err := h.server.RegisterWorker(channel, body.Arch) if err != nil { return api.HTTPErrorWithInternal(api.ErrorInsertingWorker, err) } diff --git a/pkg/jobqueue/dbjobqueue/dbjobqueue.go b/pkg/jobqueue/dbjobqueue/dbjobqueue.go index 53e329eb1..af33d5660 100644 --- a/pkg/jobqueue/dbjobqueue/dbjobqueue.go +++ b/pkg/jobqueue/dbjobqueue/dbjobqueue.go @@ -117,14 +117,14 @@ const ( WHERE worker_id = $1` sqlInsertWorker = ` - INSERT INTO workers(worker_id, arch, heartbeat) - VALUES($1, $2, now())` + INSERT INTO workers(worker_id, channel, arch, heartbeat) + VALUES($1, $2, $3, now())` sqlUpdateWorkerStatus = ` UPDATE workers SET heartbeat = now() WHERE worker_id = $1` sqlQueryWorkers = ` - SELECT worker_id, arch + SELECT worker_id, channel, arch FROM workers WHERE age(now(), heartbeat) > $1` sqlDeleteWorker = ` @@ -713,7 +713,7 @@ func (q *DBJobQueue) RefreshHeartbeat(token uuid.UUID) { } } -func (q *DBJobQueue) InsertWorker(arch string) (uuid.UUID, error) { +func (q *DBJobQueue) InsertWorker(channel, arch string) (uuid.UUID, error) { conn, err := q.pool.Acquire(context.Background()) if err != nil { return uuid.Nil, err @@ -721,7 +721,7 @@ func (q *DBJobQueue) InsertWorker(arch string) (uuid.UUID, error) { defer conn.Release() id := uuid.New() - _, err = conn.Exec(context.Background(), sqlInsertWorker, id, arch) + _, err = conn.Exec(context.Background(), sqlInsertWorker, id, channel, arch) if err != nil { q.logger.Error(err, "Error inserting worker") return uuid.Nil, err @@ -763,16 +763,18 @@ func (q *DBJobQueue) Workers(olderThan time.Duration) ([]jobqueue.Worker, error) workers := make([]jobqueue.Worker, 0) for rows.Next() { var w uuid.UUID + var c string var a string - err = rows.Scan(&w, &a) + err = rows.Scan(&w, &c, &a) if err != nil { // Log the error and try to continue with the next row q.logger.Error(err, "Unable to read token from heartbeats") continue } workers = append(workers, jobqueue.Worker{ - ID: w, - Arch: a, + ID: w, + Channel: c, + Arch: a, }) } if rows.Err() != nil { diff --git a/pkg/jobqueue/dbjobqueue/schemas/008_workers_add_channel.sql b/pkg/jobqueue/dbjobqueue/schemas/008_workers_add_channel.sql new file mode 100644 index 000000000..82359c94c --- /dev/null +++ b/pkg/jobqueue/dbjobqueue/schemas/008_workers_add_channel.sql @@ -0,0 +1,2 @@ +ALTER TABLE workers +ADD COLUMN channel varchar NOT NULL DEFAULT ''; diff --git a/pkg/jobqueue/jobqueue.go b/pkg/jobqueue/jobqueue.go index 5c50e874b..cb88ee39a 100644 --- a/pkg/jobqueue/jobqueue.go +++ b/pkg/jobqueue/jobqueue.go @@ -82,7 +82,7 @@ type JobQueue interface { RefreshHeartbeat(token uuid.UUID) // Inserts the worker and creates a UUID for it - InsertWorker(arch string) (uuid.UUID, error) + InsertWorker(channel, arch string) (uuid.UUID, error) // Reset the last worker's heartbeat time to time.Now() UpdateWorkerStatus(workerID uuid.UUID) error @@ -117,6 +117,7 @@ var ( ) type Worker struct { - ID uuid.UUID - Arch string + ID uuid.UUID + Channel string + Arch string }