Skip to content

Commit 35d82ea

Browse files
committed
implement preserving env from host into vm in shell command
Signed-off-by: olalekan odukoya <[email protected]>
1 parent 96c7179 commit 35d82ea

File tree

5 files changed

+368
-2
lines changed

5 files changed

+368
-2
lines changed

cmd/limactl/shell.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/sirupsen/logrus"
2020
"github.com/spf13/cobra"
2121

22+
"github.com/lima-vm/lima/v2/pkg/envutil"
2223
"github.com/lima-vm/lima/v2/pkg/instance"
2324
"github.com/lima-vm/lima/v2/pkg/ioutilx"
2425
"github.com/lima-vm/lima/v2/pkg/limayaml"
@@ -54,6 +55,7 @@ func newShellCommand() *cobra.Command {
5455
shellCmd.Flags().String("shell", "", "Shell interpreter, e.g. /bin/bash")
5556
shellCmd.Flags().String("workdir", "", "Working directory")
5657
shellCmd.Flags().Bool("reconnect", false, "Reconnect to the SSH session")
58+
shellCmd.Flags().Bool("preserve-env", false, "Propagate environment variables to the shell")
5759
return shellCmd
5860
}
5961

@@ -178,7 +180,23 @@ func shellAction(cmd *cobra.Command, args []string) error {
178180
} else {
179181
shell = shellescape.Quote(shell)
180182
}
181-
script := fmt.Sprintf("%s ; exec %s --login", changeDirCmd, shell)
183+
// Handle environment variable propagation
184+
var envPrefix string
185+
preserveEnv, err := cmd.Flags().GetBool("preserve-env")
186+
if err != nil {
187+
return err
188+
}
189+
if preserveEnv {
190+
filteredEnv := envutil.FilterEnvironment()
191+
if len(filteredEnv) > 0 {
192+
envPrefix = "env "
193+
for _, envVar := range filteredEnv {
194+
envPrefix += shellescape.Quote(envVar) + " "
195+
}
196+
}
197+
}
198+
199+
script := fmt.Sprintf("%s ; exec %s%s --login", changeDirCmd, envPrefix, shell)
182200
if len(args) > 1 {
183201
quotedArgs := make([]string, len(args[1:]))
184202
parsingEnv := true

cmd/nerdctl.lima

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/bin/sh
22
set -eu
3-
exec lima nerdctl "$@"
3+
exec limactl shell --preserve-env default nerdctl "$@"

pkg/envutil/envutil.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"slices"
9+
"strings"
10+
11+
"github.com/sirupsen/logrus"
12+
)
13+
14+
// defaultBlockList contains environment variables that should not be propagated by default.
15+
var defaultBlockList = []string{
16+
"BASH*",
17+
"DISPLAY",
18+
"DYLD_*",
19+
"EUID",
20+
"FPATH",
21+
"GID",
22+
"GROUP",
23+
"HOME",
24+
"HOSTNAME",
25+
"LD_*",
26+
"LOGNAME",
27+
"OLDPWD",
28+
"PATH",
29+
"PWD",
30+
"SHELL",
31+
"SHLVL",
32+
"SSH_*",
33+
"TERM",
34+
"TERMINFO",
35+
"TMPDIR",
36+
"UID",
37+
"USER",
38+
"XAUTHORITY",
39+
"XDG_*",
40+
"ZDOTDIR",
41+
"ZSH*",
42+
"_*", // Variables starting with underscore are typically internal
43+
}
44+
45+
func getBlockList() []string {
46+
blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK")
47+
if blockEnv == "" {
48+
return defaultBlockList
49+
}
50+
after, found := strings.CutPrefix(blockEnv, "+")
51+
if !found {
52+
return parseEnvList(blockEnv)
53+
}
54+
return slices.Concat(defaultBlockList, parseEnvList(after))
55+
}
56+
57+
func getAllowList() []string {
58+
if allowEnv := os.Getenv("LIMA_SHELLENV_ALLOW"); allowEnv != "" {
59+
return parseEnvList(allowEnv)
60+
}
61+
return nil
62+
}
63+
64+
func parseEnvList(envList string) []string {
65+
parts := strings.Split(envList, ",")
66+
result := make([]string, 0, len(parts))
67+
for _, part := range parts {
68+
if trimmed := strings.TrimSpace(part); trimmed != "" {
69+
result = append(result, trimmed)
70+
}
71+
}
72+
73+
return result
74+
}
75+
76+
func matchesPattern(name, pattern string) bool {
77+
if pattern == name {
78+
return true
79+
}
80+
81+
prefix, found := strings.CutSuffix(pattern, "*")
82+
return found && strings.HasPrefix(name, prefix)
83+
}
84+
85+
func matchesAnyPattern(name string, patterns []string) bool {
86+
return slices.ContainsFunc(patterns, func(pattern string) bool {
87+
return matchesPattern(name, pattern)
88+
})
89+
}
90+
91+
// FilterEnvironment filters environment variables based on configuration from environment variables.
92+
// It returns a slice of environment variables that are not blocked by the current configuration.
93+
// The filtering is controlled by LIMA_SHELLENV_BLOCK and LIMA_SHELLENV_ALLOW environment variables.
94+
func FilterEnvironment() []string {
95+
return filterEnvironmentWithLists(
96+
os.Environ(),
97+
getAllowList(),
98+
getBlockList(),
99+
)
100+
}
101+
102+
func filterEnvironmentWithLists(env, allowList, blockList []string) []string {
103+
var filtered []string
104+
105+
for _, envVar := range env {
106+
parts := strings.SplitN(envVar, "=", 2)
107+
if len(parts) != 2 {
108+
continue
109+
}
110+
111+
name := parts[0]
112+
113+
if len(allowList) > 0 {
114+
if !matchesAnyPattern(name, allowList) {
115+
continue
116+
}
117+
filtered = append(filtered, envVar)
118+
continue
119+
}
120+
121+
if matchesAnyPattern(name, blockList) {
122+
logrus.Debugf("Blocked env variable %q", name)
123+
continue
124+
}
125+
126+
filtered = append(filtered, envVar)
127+
}
128+
129+
return filtered
130+
}
131+
132+
// GetDefaultBlockList returns a copy of the default block list.
133+
func GetDefaultBlockList() []string {
134+
return slices.Clone(defaultBlockList)
135+
}

pkg/envutil/envutil_test.go

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"slices"
9+
"strings"
10+
"testing"
11+
12+
"gotest.tools/v3/assert"
13+
)
14+
15+
func isUsingDefaultBlockList() bool {
16+
blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK")
17+
return blockEnv == "" || strings.HasPrefix(blockEnv, "+")
18+
}
19+
20+
func TestMatchesPattern(t *testing.T) {
21+
tests := []struct {
22+
name string
23+
pattern string
24+
expected bool
25+
}{
26+
{"PATH", "PATH", true},
27+
{"PATH", "HOME", false},
28+
{"SSH_AUTH_SOCK", "SSH_*", true},
29+
{"SSH_AGENT_PID", "SSH_*", true},
30+
{"HOME", "SSH_*", false},
31+
{"XDG_CONFIG_HOME", "XDG_*", true},
32+
{"_LIMA_TEST", "_*", true},
33+
{"LIMA_HOME", "_*", false},
34+
}
35+
36+
for _, tt := range tests {
37+
t.Run(tt.name+"_matches_"+tt.pattern, func(t *testing.T) {
38+
result := matchesPattern(tt.name, tt.pattern)
39+
assert.Equal(t, result, tt.expected)
40+
})
41+
}
42+
}
43+
44+
func TestMatchesAnyPattern(t *testing.T) {
45+
patterns := []string{"PATH", "SSH_*", "XDG_*"}
46+
47+
tests := []struct {
48+
name string
49+
expected bool
50+
}{
51+
{"PATH", true},
52+
{"HOME", false},
53+
{"SSH_AUTH_SOCK", true},
54+
{"XDG_CONFIG_HOME", true},
55+
{"USER", false},
56+
}
57+
58+
for _, tt := range tests {
59+
t.Run(tt.name, func(t *testing.T) {
60+
result := matchesAnyPattern(tt.name, patterns)
61+
assert.Equal(t, result, tt.expected)
62+
})
63+
}
64+
}
65+
66+
func TestParseEnvList(t *testing.T) {
67+
tests := []struct {
68+
input string
69+
expected []string
70+
}{
71+
{"", []string{}},
72+
{"PATH", []string{"PATH"}},
73+
{"PATH,HOME", []string{"PATH", "HOME"}},
74+
{"PATH, HOME , USER", []string{"PATH", "HOME", "USER"}},
75+
{" , , ", []string{}},
76+
}
77+
78+
for _, tt := range tests {
79+
t.Run(tt.input, func(t *testing.T) {
80+
result := parseEnvList(tt.input)
81+
assert.DeepEqual(t, result, tt.expected)
82+
})
83+
}
84+
}
85+
86+
func TestGetBlockAndAllowLists(t *testing.T) {
87+
t.Run("default config", func(t *testing.T) {
88+
t.Setenv("LIMA_SHELLENV_BLOCK", "")
89+
t.Setenv("LIMA_SHELLENV_ALLOW", "")
90+
91+
blockList := getBlockList()
92+
allowList := getAllowList()
93+
94+
assert.Assert(t, isUsingDefaultBlockList())
95+
assert.DeepEqual(t, blockList, defaultBlockList)
96+
assert.Equal(t, len(allowList), 0)
97+
})
98+
99+
t.Run("custom blocklist", func(t *testing.T) {
100+
t.Setenv("LIMA_SHELLENV_BLOCK", "PATH,HOME")
101+
102+
blockList := getBlockList()
103+
assert.Assert(t, !isUsingDefaultBlockList())
104+
expected := []string{"PATH", "HOME"}
105+
assert.DeepEqual(t, blockList, expected)
106+
})
107+
108+
t.Run("additive blocklist", func(t *testing.T) {
109+
t.Setenv("LIMA_SHELLENV_BLOCK", "+CUSTOM_VAR")
110+
111+
blockList := getBlockList()
112+
assert.Assert(t, isUsingDefaultBlockList())
113+
expected := slices.Concat(GetDefaultBlockList(), []string{"CUSTOM_VAR"})
114+
assert.DeepEqual(t, blockList, expected)
115+
})
116+
117+
t.Run("allowlist", func(t *testing.T) {
118+
t.Setenv("LIMA_SHELLENV_ALLOW", "FOO,BAR")
119+
120+
allowList := getAllowList()
121+
expected := []string{"FOO", "BAR"}
122+
assert.DeepEqual(t, allowList, expected)
123+
})
124+
}
125+
126+
func TestFilterEnvironment(t *testing.T) {
127+
testEnv := []string{
128+
"PATH=/usr/bin",
129+
"HOME=/home/user",
130+
"USER=testuser",
131+
"FOO=bar",
132+
"SSH_AUTH_SOCK=/tmp/ssh",
133+
"XDG_CONFIG_HOME=/config",
134+
"BASH_VERSION=5.0",
135+
"_INTERNAL=secret",
136+
"CUSTOM_VAR=value",
137+
}
138+
139+
t.Run("default blocklist", func(t *testing.T) {
140+
result := filterEnvironmentWithLists(testEnv, nil, defaultBlockList)
141+
142+
expected := []string{"FOO=bar", "CUSTOM_VAR=value"}
143+
assert.Assert(t, containsAll(result, expected))
144+
145+
blockedPrefixes := []string{
146+
"PATH=",
147+
"HOME=",
148+
"SSH_AUTH_SOCK=",
149+
"XDG_CONFIG_HOME=",
150+
"BASH_VERSION=",
151+
"_INTERNAL=",
152+
}
153+
for _, prefix := range blockedPrefixes {
154+
for _, envVar := range result {
155+
assert.Assert(t, !strings.HasPrefix(envVar, prefix), "Expected result to not contain variable with prefix %q, but found %q", prefix, envVar)
156+
}
157+
}
158+
})
159+
160+
t.Run("custom blocklist", func(t *testing.T) {
161+
result := filterEnvironmentWithLists(testEnv, nil, []string{"FOO"})
162+
163+
assert.Assert(t, !slices.Contains(result, "FOO=bar"))
164+
165+
expected := []string{"PATH=/usr/bin", "HOME=/home/user", "USER=testuser"}
166+
assert.Assert(t, containsAll(result, expected))
167+
})
168+
169+
t.Run("allowlist", func(t *testing.T) {
170+
result := filterEnvironmentWithLists(testEnv, []string{"FOO", "USER"}, nil)
171+
172+
expected := []string{"FOO=bar", "USER=testuser"}
173+
assert.Equal(t, len(result), len(expected))
174+
assert.Assert(t, containsAll(result, expected))
175+
})
176+
177+
t.Run("allowlist takes precedence over blocklist", func(t *testing.T) {
178+
result := filterEnvironmentWithLists(testEnv, []string{"FOO", "CUSTOM_VAR"}, []string{"FOO", "USER"})
179+
180+
expected := []string{"FOO=bar", "CUSTOM_VAR=value"}
181+
assert.Assert(t, containsAll(result, expected))
182+
183+
assert.Assert(t, !slices.Contains(result, "USER=testuser"))
184+
})
185+
}
186+
187+
func containsAll(slice, items []string) bool {
188+
for _, item := range items {
189+
if !slices.Contains(slice, item) {
190+
return false
191+
}
192+
}
193+
return true
194+
}
195+
196+
func TestGetDefaultBlockList(t *testing.T) {
197+
blocklist := GetDefaultBlockList()
198+
199+
if &blocklist[0] == &defaultBlockList[0] {
200+
t.Error("GetDefaultBlockList should return a copy, not the original slice")
201+
}
202+
203+
assert.DeepEqual(t, blocklist, defaultBlockList)
204+
205+
expectedItems := []string{"PATH", "HOME", "SSH_*"}
206+
for _, item := range expectedItems {
207+
found := slices.Contains(blocklist, item)
208+
assert.Assert(t, found, "Expected builtin blocklist to contain %q", item)
209+
}
210+
}

0 commit comments

Comments
 (0)