Skip to content

Commit 53e21f5

Browse files
committed
implement preserving env from host into vm in shell command
Signed-off-by: olalekan odukoya <[email protected]>
1 parent 96c7179 commit 53e21f5

File tree

5 files changed

+371
-2
lines changed

5 files changed

+371
-2
lines changed

cmd/limactl/shell.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/sirupsen/logrus"
2020
"github.com/spf13/cobra"
2121

22+
"github.com/lima-vm/lima/v2/pkg/envutil"
2223
"github.com/lima-vm/lima/v2/pkg/instance"
2324
"github.com/lima-vm/lima/v2/pkg/ioutilx"
2425
"github.com/lima-vm/lima/v2/pkg/limayaml"
@@ -54,6 +55,7 @@ func newShellCommand() *cobra.Command {
5455
shellCmd.Flags().String("shell", "", "Shell interpreter, e.g. /bin/bash")
5556
shellCmd.Flags().String("workdir", "", "Working directory")
5657
shellCmd.Flags().Bool("reconnect", false, "Reconnect to the SSH session")
58+
shellCmd.Flags().Bool("preserve-env", false, "Propagate environment variables to the shell")
5759
return shellCmd
5860
}
5961

@@ -178,7 +180,24 @@ func shellAction(cmd *cobra.Command, args []string) error {
178180
} else {
179181
shell = shellescape.Quote(shell)
180182
}
181-
script := fmt.Sprintf("%s ; exec %s --login", changeDirCmd, shell)
183+
// Handle environment variable propagation
184+
var envPrefix string
185+
preserveEnv, err := cmd.Flags().GetBool("preserve-env")
186+
if err != nil {
187+
return err
188+
}
189+
if preserveEnv {
190+
filteredEnv := envutil.FilterEnvironment()
191+
if len(filteredEnv) > 0 {
192+
envVars := make([]string, len(filteredEnv))
193+
for i, envVar := range filteredEnv {
194+
envVars[i] = shellescape.Quote(envVar)
195+
}
196+
envPrefix = "env " + strings.Join(envVars, " ") + " "
197+
}
198+
}
199+
200+
script := fmt.Sprintf("%s ; exec %s%s --login", changeDirCmd, envPrefix, shell)
182201
if len(args) > 1 {
183202
quotedArgs := make([]string, len(args[1:]))
184203
parsingEnv := true

cmd/nerdctl.lima

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/bin/sh
22
set -eu
3-
exec lima nerdctl "$@"
3+
exec limactl shell --preserve-env default nerdctl "$@"

pkg/envutil/envutil.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"slices"
9+
"strings"
10+
11+
"github.com/sirupsen/logrus"
12+
)
13+
14+
// DefaultBlockList contains environment variables that should not be propagated by default.
15+
var DefaultBlockList = []string{
16+
"BASH*",
17+
"DISPLAY",
18+
"DYLD_*",
19+
"EUID",
20+
"FPATH",
21+
"GID",
22+
"GROUP",
23+
"HOME",
24+
"HOSTNAME",
25+
"LD_*",
26+
"LOGNAME",
27+
"OLDPWD",
28+
"PATH",
29+
"PWD",
30+
"SHELL",
31+
"SHLVL",
32+
"SSH_*",
33+
"TERM",
34+
"TERMINFO",
35+
"TMPDIR",
36+
"UID",
37+
"USER",
38+
"XAUTHORITY",
39+
"XDG_*",
40+
"ZDOTDIR",
41+
"ZSH*",
42+
"_*", // Variables starting with underscore are typically internal
43+
}
44+
45+
func getBlockList() []string {
46+
blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK")
47+
if blockEnv == "" {
48+
return DefaultBlockList
49+
}
50+
after, found := strings.CutPrefix(blockEnv, "+")
51+
if !found {
52+
return parseEnvList(blockEnv)
53+
}
54+
return slices.Concat(DefaultBlockList, parseEnvList(after))
55+
}
56+
57+
func getAllowList() []string {
58+
if allowEnv := os.Getenv("LIMA_SHELLENV_ALLOW"); allowEnv != "" {
59+
return parseEnvList(allowEnv)
60+
}
61+
return nil
62+
}
63+
64+
func IsUsingBuiltinBlockList() bool {
65+
blockEnv := os.Getenv("LIMA_SHELLENV_BLOCK")
66+
return blockEnv == "" || strings.HasPrefix(blockEnv, "+")
67+
}
68+
69+
func parseEnvList(envList string) []string {
70+
if envList == "" {
71+
return nil
72+
}
73+
74+
parts := strings.Split(envList, ",")
75+
result := make([]string, 0, len(parts))
76+
for _, part := range parts {
77+
if trimmed := strings.TrimSpace(part); trimmed != "" {
78+
result = append(result, trimmed)
79+
}
80+
}
81+
82+
return result
83+
}
84+
85+
func matchesPattern(name, pattern string) bool {
86+
if pattern == name {
87+
return true
88+
}
89+
90+
prefix, found := strings.CutSuffix(pattern, "*")
91+
return found && strings.HasPrefix(name, prefix)
92+
}
93+
94+
func matchesAnyPattern(name string, patterns []string) bool {
95+
return slices.ContainsFunc(patterns, func(pattern string) bool {
96+
return matchesPattern(name, pattern)
97+
})
98+
}
99+
100+
func FilterEnvironment() []string {
101+
return FilterEnvironmentWithLists(os.Environ(), getBlockList(), getAllowList())
102+
}
103+
104+
func FilterEnvironmentWithLists(env, blockList, allowList []string) []string {
105+
var filtered []string
106+
107+
for _, envVar := range env {
108+
parts := strings.SplitN(envVar, "=", 2)
109+
if len(parts) != 2 {
110+
continue
111+
}
112+
113+
name := parts[0]
114+
115+
if len(allowList) > 0 {
116+
if !matchesAnyPattern(name, allowList) {
117+
continue
118+
}
119+
filtered = append(filtered, envVar)
120+
continue
121+
}
122+
123+
if matchesAnyPattern(name, blockList) {
124+
logrus.Debugf("Blocked env variable %q", name)
125+
continue
126+
}
127+
128+
filtered = append(filtered, envVar)
129+
}
130+
131+
return filtered
132+
}
133+
134+
func GetBuiltinBlockList() []string {
135+
return slices.Clone(DefaultBlockList)
136+
}

pkg/envutil/envutil_test.go

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
// SPDX-FileCopyrightText: Copyright The Lima Authors
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package envutil
5+
6+
import (
7+
"os"
8+
"slices"
9+
"testing"
10+
11+
"gotest.tools/v3/assert"
12+
)
13+
14+
func TestMatchesPattern(t *testing.T) {
15+
tests := []struct {
16+
name string
17+
pattern string
18+
expected bool
19+
}{
20+
{"PATH", "PATH", true},
21+
{"PATH", "HOME", false},
22+
{"SSH_AUTH_SOCK", "SSH_*", true},
23+
{"SSH_AGENT_PID", "SSH_*", true},
24+
{"HOME", "SSH_*", false},
25+
{"XDG_CONFIG_HOME", "XDG_*", true},
26+
{"_LIMA_TEST", "_*", true},
27+
{"LIMA_HOME", "_*", false},
28+
}
29+
30+
for _, tt := range tests {
31+
t.Run(tt.name+"_matches_"+tt.pattern, func(t *testing.T) {
32+
result := matchesPattern(tt.name, tt.pattern)
33+
assert.Equal(t, result, tt.expected)
34+
})
35+
}
36+
}
37+
38+
func TestMatchesAnyPattern(t *testing.T) {
39+
patterns := []string{"PATH", "SSH_*", "XDG_*"}
40+
41+
tests := []struct {
42+
name string
43+
expected bool
44+
}{
45+
{"PATH", true},
46+
{"HOME", false},
47+
{"SSH_AUTH_SOCK", true},
48+
{"XDG_CONFIG_HOME", true},
49+
{"USER", false},
50+
}
51+
52+
for _, tt := range tests {
53+
t.Run(tt.name, func(t *testing.T) {
54+
result := matchesAnyPattern(tt.name, patterns)
55+
assert.Equal(t, result, tt.expected)
56+
})
57+
}
58+
}
59+
60+
func TestParseEnvList(t *testing.T) {
61+
tests := []struct {
62+
input string
63+
expected []string
64+
}{
65+
{"", nil},
66+
{"PATH", []string{"PATH"}},
67+
{"PATH,HOME", []string{"PATH", "HOME"}},
68+
{"PATH, HOME , USER", []string{"PATH", "HOME", "USER"}},
69+
{" , , ", []string{}},
70+
}
71+
72+
for _, tt := range tests {
73+
t.Run(tt.input, func(t *testing.T) {
74+
result := parseEnvList(tt.input)
75+
assert.DeepEqual(t, result, tt.expected)
76+
})
77+
}
78+
}
79+
80+
func TestGetBlockAndAllowLists(t *testing.T) {
81+
originalBlock := os.Getenv("LIMA_SHELLENV_BLOCK")
82+
originalAllow := os.Getenv("LIMA_SHELLENV_ALLOW")
83+
defer func() {
84+
if originalBlock != "" {
85+
t.Setenv("LIMA_SHELLENV_BLOCK", originalBlock)
86+
} else {
87+
os.Unsetenv("LIMA_SHELLENV_BLOCK")
88+
}
89+
if originalAllow != "" {
90+
t.Setenv("LIMA_SHELLENV_ALLOW", originalAllow)
91+
} else {
92+
os.Unsetenv("LIMA_SHELLENV_ALLOW")
93+
}
94+
}()
95+
96+
t.Run("default config", func(t *testing.T) {
97+
t.Setenv("LIMA_SHELLENV_BLOCK", "")
98+
t.Setenv("LIMA_SHELLENV_ALLOW", "")
99+
100+
blockList := getBlockList()
101+
allowList := getAllowList()
102+
103+
assert.Assert(t, IsUsingBuiltinBlockList())
104+
assert.DeepEqual(t, blockList, DefaultBlockList)
105+
assert.Equal(t, len(allowList), 0)
106+
})
107+
108+
t.Run("custom blocklist", func(t *testing.T) {
109+
t.Setenv("LIMA_SHELLENV_BLOCK", "PATH,HOME")
110+
t.Setenv("LIMA_SHELLENV_ALLOW", "")
111+
112+
blockList := getBlockList()
113+
assert.Assert(t, !IsUsingBuiltinBlockList())
114+
expected := []string{"PATH", "HOME"}
115+
assert.DeepEqual(t, blockList, expected)
116+
})
117+
118+
t.Run("additive blocklist", func(t *testing.T) {
119+
t.Setenv("LIMA_SHELLENV_BLOCK", "+CUSTOM_VAR")
120+
t.Setenv("LIMA_SHELLENV_ALLOW", "")
121+
122+
blockList := getBlockList()
123+
assert.Assert(t, IsUsingBuiltinBlockList())
124+
expected := slices.Concat(GetBuiltinBlockList(), []string{"CUSTOM_VAR"})
125+
assert.DeepEqual(t, blockList, expected)
126+
})
127+
128+
t.Run("allowlist", func(t *testing.T) {
129+
t.Setenv("LIMA_SHELLENV_BLOCK", "")
130+
t.Setenv("LIMA_SHELLENV_ALLOW", "FOO,BAR")
131+
132+
allowList := getAllowList()
133+
expected := []string{"FOO", "BAR"}
134+
assert.DeepEqual(t, allowList, expected)
135+
})
136+
}
137+
138+
func TestFilterEnvironment(t *testing.T) {
139+
testEnv := []string{
140+
"PATH=/usr/bin",
141+
"HOME=/home/user",
142+
"USER=testuser",
143+
"FOO=bar",
144+
"SSH_AUTH_SOCK=/tmp/ssh",
145+
"XDG_CONFIG_HOME=/config",
146+
"BASH_VERSION=5.0",
147+
"_INTERNAL=secret",
148+
"CUSTOM_VAR=value",
149+
}
150+
151+
t.Run("default blocklist", func(t *testing.T) {
152+
result := FilterEnvironmentWithLists(testEnv, DefaultBlockList, nil)
153+
154+
expected := []string{"FOO=bar", "CUSTOM_VAR=value"}
155+
assert.Assert(t, containsAll(result, expected))
156+
157+
blocked := []string{
158+
"PATH=/usr/bin",
159+
"HOME=/home/user",
160+
"SSH_AUTH_SOCK=/tmp/ssh",
161+
"XDG_CONFIG_HOME=/config",
162+
"BASH_VERSION=5.0",
163+
"_INTERNAL=secret",
164+
}
165+
for _, blockedVar := range blocked {
166+
assert.Assert(t, !slices.Contains(result, blockedVar), "Expected result to not contain blocked variable %q", blockedVar)
167+
}
168+
})
169+
170+
t.Run("custom blocklist", func(t *testing.T) {
171+
result := FilterEnvironmentWithLists(testEnv, []string{"FOO"}, nil)
172+
173+
assert.Assert(t, !slices.Contains(result, "FOO=bar"))
174+
175+
expected := []string{"PATH=/usr/bin", "HOME=/home/user", "USER=testuser"}
176+
assert.Assert(t, containsAll(result, expected))
177+
})
178+
179+
t.Run("allowlist", func(t *testing.T) {
180+
result := FilterEnvironmentWithLists(testEnv, nil, []string{"FOO", "USER"})
181+
182+
expected := []string{"FOO=bar", "USER=testuser"}
183+
assert.Equal(t, len(result), len(expected))
184+
assert.Assert(t, containsAll(result, expected))
185+
})
186+
}
187+
188+
func containsAll(slice, items []string) bool {
189+
for _, item := range items {
190+
if !slices.Contains(slice, item) {
191+
return false
192+
}
193+
}
194+
return true
195+
}
196+
197+
func TestGetBuiltinBlockList(t *testing.T) {
198+
blocklist := GetBuiltinBlockList()
199+
200+
if &blocklist[0] == &DefaultBlockList[0] {
201+
t.Error("GetBuiltinBlockList should return a copy, not the original slice")
202+
}
203+
204+
assert.DeepEqual(t, blocklist, DefaultBlockList)
205+
206+
expectedItems := []string{"PATH", "HOME", "SSH_*"}
207+
for _, item := range expectedItems {
208+
found := slices.Contains(blocklist, item)
209+
assert.Assert(t, found, "Expected builtin blocklist to contain %q", item)
210+
}
211+
}

0 commit comments

Comments
 (0)