diff --git a/cmd/osbuild-worker/export_test.go b/cmd/osbuild-worker/export_test.go index df64dcdfe..923eef7af 100644 --- a/cmd/osbuild-worker/export_test.go +++ b/cmd/osbuild-worker/export_test.go @@ -3,4 +3,13 @@ package main var ( WorkerClientErrorFrom = workerClientErrorFrom MakeJobErrorFromOsbuildOutput = makeJobErrorFromOsbuildOutput + Main = main ) + +func MockRun(new func()) (restore func()) { + saved := run + run = new + return func() { + run = saved + } +} diff --git a/cmd/osbuild-worker/main.go b/cmd/osbuild-worker/main.go index dc4eb6958..bd95e7115 100644 --- a/cmd/osbuild-worker/main.go +++ b/cmd/osbuild-worker/main.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path" + "runtime/debug" "strings" "time" @@ -166,7 +167,7 @@ func RequestAndRunJob(client *worker.Client, acceptedJobTypes []string, jobImpls return nil } -func main() { +var run = func() { var unix bool flag.BoolVar(&unix, "unix", false, "Interpret 'address' as a path to a unix domain socket instead of a network address") @@ -534,3 +535,13 @@ func main() { } } } + +func main() { + defer func() { + if r := recover(); r != nil { + logrus.Fatalf("worker crashed: %s\n%s", r, debug.Stack()) + } + }() + + run() +} diff --git a/cmd/osbuild-worker/main_test.go b/cmd/osbuild-worker/main_test.go new file mode 100644 index 000000000..a695f0d92 --- /dev/null +++ b/cmd/osbuild-worker/main_test.go @@ -0,0 +1,41 @@ +package main_test + +import ( + "io" + "testing" + + "github.com/sirupsen/logrus" + logrusTest "github.com/sirupsen/logrus/hooks/test" + "github.com/stretchr/testify/assert" + + main "github.com/osbuild/osbuild-composer/cmd/osbuild-worker" +) + +func TestCatchesPanic(t *testing.T) { + restore := main.MockRun(func() { + // simulate a crash in the main code + var foo *int + println(*foo) + }) + defer restore() + + // logrus setup is a bit cumbersome as we need to modify both + // the standard logger and add a mock logger. + var exitCalls []int + logrus.StandardLogger().ExitFunc = func(exitCode int) { + exitCalls = append(exitCalls, exitCode) + } + logrus.SetOutput(io.Discard) + _, hook := logrusTest.NewNullLogger() + logrus.AddHook(hook) + + main.Main() + // ensure both message and stracktrace are reported in full + assert.Equal(t, logrus.FatalLevel, hook.LastEntry().Level) + msg := hook.LastEntry().Message + assert.Contains(t, msg, "worker crashed: runtime error: invalid memory address or nil pointer dereference") + assert.Contains(t, msg, "runtime/debug.Stack()") + assert.Contains(t, msg, "osbuild-worker_test.TestCatchesPanic.func1()") + + assert.Equal(t, []int{1}, exitCalls) +}