Skip to content

Commit 35c0bb5

Browse files
committed
implement preserving env from host into vm in shell command
Signed-off-by: olalekan odukoya <[email protected]>
1 parent 96c7179 commit 35c0bb5

File tree

5 files changed

+372
-2
lines changed

5 files changed

+372
-2
lines changed

cmd/limactl/shell.go

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/sirupsen/logrus"
2020
"github.com/spf13/cobra"
2121

22+
"github.com/lima-vm/lima/v2/pkg/envutil"
2223
"github.com/lima-vm/lima/v2/pkg/instance"
2324
"github.com/lima-vm/lima/v2/pkg/ioutilx"
2425
"github.com/lima-vm/lima/v2/pkg/limayaml"
@@ -54,6 +55,7 @@ func newShellCommand() *cobra.Command {
5455
shellCmd.Flags().String("shell", "", "Shell interpreter, e.g. /bin/bash")
5556
shellCmd.Flags().String("workdir", "", "Working directory")
5657
shellCmd.Flags().Bool("reconnect", false, "Reconnect to the SSH session")
58+
shellCmd.Flags().Bool("preserve-env", false, "Propagate environment variables to the shell")
5759
return shellCmd
5860
}
5961

@@ -178,7 +180,23 @@ func shellAction(cmd *cobra.Command, args []string) error {
178180
} else {
179181
shell = shellescape.Quote(shell)
180182
}
181-
script := fmt.Sprintf("%s ; exec %s --login", changeDirCmd, shell)
183+
// Handle environment variable propagation
184+
var envPrefix string
185+
preserveEnv, err := cmd.Flags().GetBool("preserve-env")
186+
if err != nil {
187+
return err
188+
}
189+
if preserveEnv {
190+
filteredEnv := envutil.FilterEnvironment()
191+
if len(filteredEnv) > 0 {
192+
envPrefix = "env "
193+
for _, envVar := range filteredEnv {
194+
envPrefix += shellescape.Quote(envVar) + " "
195+
}
196+
}
197+
}
198+
199+
script := fmt.Sprintf("%s ; exec %s%s --login", changeDirCmd, envPrefix, shell)
182200
if len(args) > 1 {
183201
quotedArgs := make([]string, len(args[1:]))
184202
parsingEnv := true

cmd/nerdctl.lima

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/bin/sh
22
set -eu
3-
exec lima nerdctl "$@"
3+
exec limactl shell --preserve-env default nerdctl "$@"

pkg/envutil/envutil.go

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"slices"
9+
"strings"
10+
11+
"github.com/sirupsen/logrus"
12+
)
13+
14+
// defaultBlockList contains environment variables that should not be propagated by default.
15+
var defaultBlockList = []string{
16+
"BASH*",
17+
"DISPLAY",
18+
"DYLD_*",
19+
"EUID",
20+
"FPATH",
21+
"GID",
22+
"GROUP",
23+
"HOME",
24+
"HOSTNAME",
25+
"LD_*",
26+
"LOGNAME",
27+
"OLDPWD",
28+
"PATH",
29+
"PWD",
30+
"SHELL",
31+
"SHLVL",
32+
"SSH_*",
33+
"TERM",
34+
"TERMINFO",
35+
"TMPDIR",
36+
"UID",
37+
"USER",
38+
"XAUTHORITY",
39+
"XDG_*",
40+
"ZDOTDIR",
41+
"ZSH*",
42+
"_*", // Variables starting with underscore are typically internal
43+
}
44+
45+
func getBlockList() []string {
46+
blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK")
47+
if blockEnv == "" {
48+
return defaultBlockList
49+
}
50+
after, found := strings.CutPrefix(blockEnv, "+")
51+
if !found {
52+
return parseEnvList(blockEnv)
53+
}
54+
return slices.Concat(defaultBlockList, parseEnvList(after))
55+
}
56+
57+
func getAllowList() []string {
58+
if allowEnv := os.Getenv("LIMA_SHELLENV_ALLOW"); allowEnv != "" {
59+
return parseEnvList(allowEnv)
60+
}
61+
return nil
62+
}
63+
64+
func parseEnvList(envList string) []string {
65+
parts := strings.Split(envList, ",")
66+
result := make([]string, 0, len(parts))
67+
for _, part := range parts {
68+
if trimmed := strings.TrimSpace(part); trimmed != "" {
69+
result = append(result, trimmed)
70+
}
71+
}
72+
73+
return result
74+
}
75+
76+
func matchesPattern(name, pattern string) bool {
77+
if pattern == name {
78+
return true
79+
}
80+
81+
prefix, found := strings.CutSuffix(pattern, "*")
82+
return found && strings.HasPrefix(name, prefix)
83+
}
84+
85+
func matchesAnyPattern(name string, patterns []string) bool {
86+
return slices.ContainsFunc(patterns, func(pattern string) bool {
87+
return matchesPattern(name, pattern)
88+
})
89+
}
90+
91+
// FilterEnvironment filters environment variables based on configuration from environment variables.
92+
// It returns a slice of environment variables that are not blocked by the current configuration.
93+
// The filtering is controlled by LIMA_SHELLENV_BLOCK and LIMA_SHELLENV_ALLOW environment variables.
94+
func FilterEnvironment() []string {
95+
return filterEnvironmentWithLists(
96+
os.Environ(),
97+
getBlockList(),
98+
getAllowList(),
99+
)
100+
}
101+
102+
func filterEnvironmentWithLists(env, blockList, allowList []string) []string {
103+
var filtered []string
104+
105+
for _, envVar := range env {
106+
parts := strings.SplitN(envVar, "=", 2)
107+
if len(parts) != 2 {
108+
continue
109+
}
110+
111+
name := parts[0]
112+
113+
if len(allowList) > 0 {
114+
if !matchesAnyPattern(name, allowList) {
115+
continue
116+
}
117+
filtered = append(filtered, envVar)
118+
continue
119+
}
120+
121+
if matchesAnyPattern(name, blockList) {
122+
logrus.Debugf("Blocked env variable %q", name)
123+
continue
124+
}
125+
126+
filtered = append(filtered, envVar)
127+
}
128+
129+
return filtered
130+
}
131+
132+
// GetDefaultBlockList returns a copy of the default block list.
133+
func GetDefaultBlockList() []string {
134+
return slices.Clone(defaultBlockList)
135+
}

pkg/envutil/envutil_test.go

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"slices"
9+
"strings"
10+
"testing"
11+
12+
"gotest.tools/v3/assert"
13+
)
14+
15+
func isUsingDefaultBlockList() bool {
16+
blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK")
17+
return blockEnv == "" || strings.HasPrefix(blockEnv, "+")
18+
}
19+
20+
func TestMatchesPattern(t *testing.T) {
21+
tests := []struct {
22+
name string
23+
pattern string
24+
expected bool
25+
}{
26+
{"PATH", "PATH", true},
27+
{"PATH", "HOME", false},
28+
{"SSH_AUTH_SOCK", "SSH_*", true},
29+
{"SSH_AGENT_PID", "SSH_*", true},
30+
{"HOME", "SSH_*", false},
31+
{"XDG_CONFIG_HOME", "XDG_*", true},
32+
{"_LIMA_TEST", "_*", true},
33+
{"LIMA_HOME", "_*", false},
34+
}
35+
36+
for _, tt := range tests {
37+
t.Run(tt.name+"_matches_"+tt.pattern, func(t *testing.T) {
38+
result := matchesPattern(tt.name, tt.pattern)
39+
assert.Equal(t, result, tt.expected)
40+
})
41+
}
42+
}
43+
44+
func TestMatchesAnyPattern(t *testing.T) {
45+
patterns := []string{"PATH", "SSH_*", "XDG_*"}
46+
47+
tests := []struct {
48+
name string
49+
expected bool
50+
}{
51+
{"PATH", true},
52+
{"HOME", false},
53+
{"SSH_AUTH_SOCK", true},
54+
{"XDG_CONFIG_HOME", true},
55+
{"USER", false},
56+
}
57+
58+
for _, tt := range tests {
59+
t.Run(tt.name, func(t *testing.T) {
60+
result := matchesAnyPattern(tt.name, patterns)
61+
assert.Equal(t, result, tt.expected)
62+
})
63+
}
64+
}
65+
66+
func TestParseEnvList(t *testing.T) {
67+
tests := []struct {
68+
input string
69+
expected []string
70+
}{
71+
{"", []string{}},
72+
{"PATH", []string{"PATH"}},
73+
{"PATH,HOME", []string{"PATH", "HOME"}},
74+
{"PATH, HOME , USER", []string{"PATH", "HOME", "USER"}},
75+
{" , , ", []string{}},
76+
}
77+
78+
for _, tt := range tests {
79+
t.Run(tt.input, func(t *testing.T) {
80+
result := parseEnvList(tt.input)
81+
assert.DeepEqual(t, result, tt.expected)
82+
})
83+
}
84+
}
85+
86+
func TestGetBlockAndAllowLists(t *testing.T) {
87+
originalBlock := os.Getenv("LIMA_SHELLENV_BLOCK")
88+
originalAllow := os.Getenv("LIMA_SHELLENV_ALLOW")
89+
defer func() {
90+
if originalBlock != "" {
91+
t.Setenv("LIMA_SHELLENV_BLOCK", originalBlock)
92+
}
93+
94+
if originalAllow != "" {
95+
t.Setenv("LIMA_SHELLENV_ALLOW", originalAllow)
96+
}
97+
}()
98+
99+
t.Run("default config", func(t *testing.T) {
100+
t.Setenv("LIMA_SHELLENV_BLOCK", "")
101+
t.Setenv("LIMA_SHELLENV_ALLOW", "")
102+
103+
blockList := getBlockList()
104+
allowList := getAllowList()
105+
106+
assert.Assert(t, isUsingDefaultBlockList())
107+
assert.DeepEqual(t, blockList, defaultBlockList)
108+
assert.Equal(t, len(allowList), 0)
109+
})
110+
111+
t.Run("custom blocklist", func(t *testing.T) {
112+
t.Setenv("LIMA_SHELLENV_BLOCK", "PATH,HOME")
113+
t.Setenv("LIMA_SHELLENV_ALLOW", "")
114+
115+
blockList := getBlockList()
116+
assert.Assert(t, !isUsingDefaultBlockList())
117+
expected := []string{"PATH", "HOME"}
118+
assert.DeepEqual(t, blockList, expected)
119+
})
120+
121+
t.Run("additive blocklist", func(t *testing.T) {
122+
t.Setenv("LIMA_SHELLENV_BLOCK", "+CUSTOM_VAR")
123+
t.Setenv("LIMA_SHELLENV_ALLOW", "")
124+
125+
blockList := getBlockList()
126+
assert.Assert(t, isUsingDefaultBlockList())
127+
expected := slices.Concat(GetDefaultBlockList(), []string{"CUSTOM_VAR"})
128+
assert.DeepEqual(t, blockList, expected)
129+
})
130+
131+
t.Run("allowlist", func(t *testing.T) {
132+
t.Setenv("LIMA_SHELLENV_BLOCK", "")
133+
t.Setenv("LIMA_SHELLENV_ALLOW", "FOO,BAR")
134+
135+
allowList := getAllowList()
136+
expected := []string{"FOO", "BAR"}
137+
assert.DeepEqual(t, allowList, expected)
138+
})
139+
}
140+
141+
func TestFilterEnvironment(t *testing.T) {
142+
testEnv := []string{
143+
"PATH=/usr/bin",
144+
"HOME=/home/user",
145+
"USER=testuser",
146+
"FOO=bar",
147+
"SSH_AUTH_SOCK=/tmp/ssh",
148+
"XDG_CONFIG_HOME=/config",
149+
"BASH_VERSION=5.0",
150+
"_INTERNAL=secret",
151+
"CUSTOM_VAR=value",
152+
}
153+
154+
t.Run("default blocklist", func(t *testing.T) {
155+
result := filterEnvironmentWithLists(testEnv, defaultBlockList, nil)
156+
157+
expected := []string{"FOO=bar", "CUSTOM_VAR=value"}
158+
assert.Assert(t, containsAll(result, expected))
159+
160+
blocked := []string{
161+
"PATH=/usr/bin",
162+
"HOME=/home/user",
163+
"SSH_AUTH_SOCK=/tmp/ssh",
164+
"XDG_CONFIG_HOME=/config",
165+
"BASH_VERSION=5.0",
166+
"_INTERNAL=secret",
167+
}
168+
for _, blockedVar := range blocked {
169+
assert.Assert(t, !slices.Contains(result, blockedVar), "Expected result to not contain blocked variable %q", blockedVar)
170+
}
171+
})
172+
173+
t.Run("custom blocklist", func(t *testing.T) {
174+
result := filterEnvironmentWithLists(testEnv, []string{"FOO"}, nil)
175+
176+
assert.Assert(t, !slices.Contains(result, "FOO=bar"))
177+
178+
expected := []string{"PATH=/usr/bin", "HOME=/home/user", "USER=testuser"}
179+
assert.Assert(t, containsAll(result, expected))
180+
})
181+
182+
t.Run("allowlist", func(t *testing.T) {
183+
result := filterEnvironmentWithLists(testEnv, nil, []string{"FOO", "USER"})
184+
185+
expected := []string{"FOO=bar", "USER=testuser"}
186+
assert.Equal(t, len(result), len(expected))
187+
assert.Assert(t, containsAll(result, expected))
188+
})
189+
}
190+
191+
func containsAll(slice, items []string) bool {
192+
for _, item := range items {
193+
if !slices.Contains(slice, item) {
194+
return false
195+
}
196+
}
197+
return true
198+
}
199+
200+
func TestGetDefaultBlockList(t *testing.T) {
201+
blocklist := GetDefaultBlockList()
202+
203+
if &blocklist[0] == &defaultBlockList[0] {
204+
t.Error("GetDefaultBlockList should return a copy, not the original slice")
205+
}
206+
207+
assert.DeepEqual(t, blocklist, defaultBlockList)
208+
209+
expectedItems := []string{"PATH", "HOME", "SSH_*"}
210+
for _, item := range expectedItems {
211+
found := slices.Contains(blocklist, item)
212+
assert.Assert(t, found, "Expected builtin blocklist to contain %q", item)
213+
}
214+
}

0 commit comments

Comments
 (0)