diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go index 92147d5..efa27d3 100644 --- a/internal/testutil/testutil.go +++ b/internal/testutil/testutil.go @@ -110,3 +110,19 @@ func CaptureStdio(t *testing.T, f func()) (string, string) { wg.Wait() return stdout.String(), stderr.String() } + +func Chdir(t *testing.T, dir string, f func()) { + cwd, err := os.Getwd() + if err != nil { + t.Errorf("%s", err.Error()) + } + defer func() { + os.Chdir(cwd) + }() + + err = os.Chdir(dir) + if err != nil { + t.Errorf("%s", err.Error()) + } + f() +} diff --git a/internal/testutil/testutil_test.go b/internal/testutil/testutil_test.go index 0162281..3067946 100644 --- a/internal/testutil/testutil_test.go +++ b/internal/testutil/testutil_test.go @@ -34,3 +34,15 @@ func TestCaptureStdout(t *testing.T) { assert.Equal(t, "output on stdout", stdout) assert.Equal(t, "output on stderr", stderr) } + +func TestChroot(t *testing.T) { + tmpdir := t.TempDir() + testutil.Chdir(t, tmpdir, func() { + cwd, err := os.Getwd() + assert.NoError(t, err) + assert.Equal(t, tmpdir, cwd) + }) + cwd, err := os.Getwd() + assert.NoError(t, err) + assert.NotEqual(t, tmpdir, cwd) +}