80 changes: 46 additions & 34 deletions internal/provider/template_resource.go
@@ -1106,47 +1106,59 @@ func waitForJob(ctx context.Context, client *codersdk.Client, version *codersdk.
 	const maxRetries = 3
 	var jobLogs []codersdk.ProvisionerJobLog

Member
Do we need this variable here now?

Contributor Author
Nope, removed it. waitForJobOnce now manages its own jobLogs internally and the caller accumulates via append(allLogs, logs...).
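For illustration, the revised loop in waitForJob might then look like this (a sketch inferred from the description above; the updated code is not part of this diff, and allLogs is just the accumulator name mentioned there):

```go
// Each attempt returns only the logs it streamed; the caller accumulates them.
var allLogs []codersdk.ProvisionerJobLog
for retries := 0; retries < maxRetries; retries++ {
	logs, done, err := waitForJobOnce(ctx, client, version)
	allLogs = append(allLogs, logs...)
	if err != nil {
		return allLogs, err
	}
	if done {
		return allLogs, nil
	}
}
return allLogs, fmt.Errorf("provisioner job did not complete after %d retries", maxRetries)
```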

 	for retries := 0; retries < maxRetries; retries++ {
-		logs, closer, err := client.TemplateVersionLogsAfter(ctx, version.ID, 0)
+		var done bool
+		var err error
+		jobLogs, done, err = waitForJobOnce(ctx, client, version, jobLogs)
 		if err != nil {
-			return jobLogs, fmt.Errorf("begin streaming logs: %w", err)
+			return jobLogs, err
 		}
-		defer func() {
-			if err := closer.Close(); err != nil {
-				tflog.Warn(ctx, "error closing template version log stream", map[string]any{
-					"error": err,
-				})
-			}
-		}()
-		for {
-			logs, ok := <-logs
-			if !ok {
-				break
-			}
-			tflog.Info(ctx, logs.Output, map[string]interface{}{
-				"job_id":     logs.ID,
-				"job_stage":  logs.Stage,
-				"log_source": logs.Source,
-				"level":      logs.Level,
-				"created_at": logs.CreatedAt,
-			})
-			if logs.Output != "" {
-				jobLogs = append(jobLogs, logs)
-			}
-		}
-		latestResp, err := client.TemplateVersion(ctx, version.ID)
-		if err != nil {
-			return jobLogs, err
-		}
-		if latestResp.Job.Status.Active() {
-			tflog.Warn(ctx, fmt.Sprintf("provisioner job still active, continuing to wait...: %s", latestResp.Job.Status))
-			continue
-		}
-		if latestResp.Job.Status != codersdk.ProvisionerJobSucceeded {
-			return jobLogs, fmt.Errorf("provisioner job did not succeed: %s (%s)", latestResp.Job.Status, latestResp.Job.Error)
-		}
-		return jobLogs, nil
+		if done {
+			return jobLogs, nil
+		}
 	}
 	return jobLogs, fmt.Errorf("provisioner job did not complete after %d retries", maxRetries)
 }

+func waitForJobOnce(ctx context.Context, client *codersdk.Client, version *codersdk.TemplateVersion, jobLogs []codersdk.ProvisionerJobLog) ([]codersdk.ProvisionerJobLog, bool, error) {

Member
Do we need to pass in jobLogs here?

Contributor Author
Good call — removed the parameter. waitForJobOnce now returns only its own logs and the caller appends them.
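The revised signature would then be roughly as follows (again a sketch; the diff below still shows the older version that takes jobLogs):

```go
// waitForJobOnce streams logs for a single attempt and reports whether the job finished.
func waitForJobOnce(ctx context.Context, client *codersdk.Client, version *codersdk.TemplateVersion) ([]codersdk.ProvisionerJobLog, bool, error)
```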

+	logs, closer, err := client.TemplateVersionLogsAfter(ctx, version.ID, 0)
+	if err != nil {
+		return jobLogs, false, fmt.Errorf("begin streaming logs: %w", err)
+	}
+	defer func() {
+		if err := closer.Close(); err != nil {
+			tflog.Warn(ctx, "error closing template version log stream", map[string]any{
+				"error": err,
+			})
+		}
+	}()
+	for {
+		logs, ok := <-logs
+		if !ok {
+			break
+		}
+		tflog.Info(ctx, logs.Output, map[string]interface{}{
+			"job_id":     logs.ID,
+			"job_stage":  logs.Stage,
+			"log_source": logs.Source,
+			"level":      logs.Level,
+			"created_at": logs.CreatedAt,
+		})
+		if logs.Output != "" {
+			jobLogs = append(jobLogs, logs)
+		}
+	}
+	latestResp, err := client.TemplateVersion(ctx, version.ID)
+	if err != nil {
+		return jobLogs, false, err
+	}
+	if latestResp.Job.Status.Active() {
+		tflog.Warn(ctx, fmt.Sprintf("provisioner job still active, continuing to wait...: %s", latestResp.Job.Status))
+		return jobLogs, false, nil
+	}
+	if latestResp.Job.Status != codersdk.ProvisionerJobSucceeded {
+		return jobLogs, false, fmt.Errorf("provisioner job did not succeed: %s (%s)", latestResp.Job.Status, latestResp.Job.Error)
+	}
+	return jobLogs, true, nil
+}

 type newVersionRequest struct {
222 changes: 222 additions & 0 deletions internal/provider/wait_for_job_test.go
Member
Thanks for adding tests! I tried checking to see if they trigger the panic by temporarily reverting your changes, and they don't seem to do so. Can you double check?

Contributor Author
You're right — the existing tests don't reproduce the original v0.0.12 panic (defer before error check), since that was already fixed in v0.0.13.

I've added TestWaitForJob_ClosesConnectionBetweenRetries which specifically tests the defer-in-loop issue this PR fixes. It tracks the maximum number of concurrently open WebSocket connections during retries. With the old defer-in-loop code, closers accumulate and connections stay open across retries (maxOpenConns > 1). With the extracted function, each connection is closed before the next retry starts (maxOpenConns == 1).
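The counting that test relies on can be sketched in isolation (a sketch only; the test itself is not part of this diff, and openConns/maxOpenConns simply mirror the names used above):

```go
package provider

import "sync/atomic"

// openConns tracks currently open log streams; maxOpenConns records the peak.
// Names are illustrative and taken from the reply above, not from the diff.
var openConns, maxOpenConns atomic.Int32

// trackOpen is called when a follow stream is accepted; the returned release
// is called when the connection is torn down.
func trackOpen() (release func()) {
	n := openConns.Add(1)
	for {
		p := maxOpenConns.Load()
		if n <= p || maxOpenConns.CompareAndSwap(p, n) {
			break
		}
	}
	return func() { openConns.Add(-1) }
}
```

With the old defer-in-loop code, releases only run when waitForJob returns, so maxOpenConns climbs past 1 across retries; with waitForJobOnce extracted, each attempt releases before the next begins and the peak stays at 1.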

@@ -0,0 +1,222 @@
package provider

import (
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"net/url"
	"strings"
	"sync/atomic"
	"testing"

	"github.com/coder/coder/v2/codersdk"
	"github.com/coder/websocket"
	"github.com/coder/websocket/wsjson"
	"github.com/google/uuid"
	"github.com/stretchr/testify/require"
)

// TestWaitForJobOnce_Success verifies that a single attempt collects the
// streamed logs and reports done=true once the job has succeeded.
func TestWaitForJobOnce_Success(t *testing.T) {
	t.Parallel()
	versionID := uuid.New()

	handler := http.NewServeMux()
	handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) {
		// The "follow" query distinguishes the log-streaming websocket
		// endpoint from the template version JSON endpoint.
		if strings.Contains(r.URL.RawQuery, "follow") {
			conn, err := websocket.Accept(w, r, nil)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			ctx := r.Context()
			_ = wsjson.Write(ctx, conn, codersdk.ProvisionerJobLog{
				ID:     1,
				Output: "test log line",
			})
			_ = conn.Close(websocket.StatusNormalClosure, "done")
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{
			ID: versionID,
			Job: codersdk.ProvisionerJob{
				Status: codersdk.ProvisionerJobSucceeded,
			},
		})
	})

	srv := httptest.NewServer(handler)
	t.Cleanup(srv.Close)
	srvURL, err := url.Parse(srv.URL)
	require.NoError(t, err)
	client := codersdk.New(srvURL)

	version := &codersdk.TemplateVersion{ID: versionID}
	logs, done, err := waitForJobOnce(context.Background(), client, version, nil)
	require.NoError(t, err)
	require.True(t, done)
	require.Len(t, logs, 1)
	require.Equal(t, "test log line", logs[0].Output)
}

// TestWaitForJobOnce_JobFailed verifies that a failed job surfaces the job's
// error message and does not report completion.
func TestWaitForJobOnce_JobFailed(t *testing.T) {
	t.Parallel()
	versionID := uuid.New()

	handler := http.NewServeMux()
	handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.RawQuery, "follow") {
			conn, err := websocket.Accept(w, r, nil)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			_ = conn.Close(websocket.StatusNormalClosure, "done")
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{
			ID: versionID,
			Job: codersdk.ProvisionerJob{
				Status: codersdk.ProvisionerJobFailed,
				Error:  "something went wrong",
			},
		})
	})

	srv := httptest.NewServer(handler)
	t.Cleanup(srv.Close)
	srvURL, err := url.Parse(srv.URL)
	require.NoError(t, err)
	client := codersdk.New(srvURL)

	version := &codersdk.TemplateVersion{ID: versionID}
	_, done, err := waitForJobOnce(context.Background(), client, version, nil)
	require.Error(t, err)
	require.False(t, done)
	require.Contains(t, err.Error(), "provisioner job did not succeed")
	require.Contains(t, err.Error(), "something went wrong")
}

// TestWaitForJobOnce_StillActive verifies that an attempt against a
// still-running job returns done=false without error, so the caller retries.
func TestWaitForJobOnce_StillActive(t *testing.T) {
	t.Parallel()
	versionID := uuid.New()

	handler := http.NewServeMux()
	handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.RawQuery, "follow") {
			conn, err := websocket.Accept(w, r, nil)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			_ = conn.Close(websocket.StatusNormalClosure, "done")
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{
			ID: versionID,
			Job: codersdk.ProvisionerJob{
				Status: codersdk.ProvisionerJobRunning,
			},
		})
	})

	srv := httptest.NewServer(handler)
	t.Cleanup(srv.Close)
	srvURL, err := url.Parse(srv.URL)
	require.NoError(t, err)
	client := codersdk.New(srvURL)

	version := &codersdk.TemplateVersion{ID: versionID}
	_, done, err := waitForJobOnce(context.Background(), client, version, nil)
	require.NoError(t, err)
	require.False(t, done)
}

// TestWaitForJob_RetriesAndCloses verifies that waitForJob opens a fresh log
// stream on each retry and gives up after maxRetries attempts.
func TestWaitForJob_RetriesAndCloses(t *testing.T) {
	t.Parallel()
	versionID := uuid.New()
	var wsConnections atomic.Int32

	handler := http.NewServeMux()
	handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.RawQuery, "follow") {
			// Count every websocket connection the client opens.
			wsConnections.Add(1)
			conn, err := websocket.Accept(w, r, nil)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			_ = conn.Close(websocket.StatusNormalClosure, "done")
			return
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{
			ID: versionID,
			Job: codersdk.ProvisionerJob{
				Status: codersdk.ProvisionerJobRunning,
			},
		})
	})

	srv := httptest.NewServer(handler)
	t.Cleanup(srv.Close)
	srvURL, err := url.Parse(srv.URL)
	require.NoError(t, err)
	client := codersdk.New(srvURL)

	version := &codersdk.TemplateVersion{ID: versionID}
	_, err = waitForJob(context.Background(), client, version)
	require.Error(t, err)
	require.Contains(t, err.Error(), "did not complete after 3 retries")
	require.Equal(t, int32(3), wsConnections.Load())
}

// TestWaitForJob_SucceedsOnRetry verifies that waitForJob keeps the logs from
// every attempt when the job only succeeds on a later retry.
func TestWaitForJob_SucceedsOnRetry(t *testing.T) {
	t.Parallel()
	versionID := uuid.New()
	var versionCallCount atomic.Int32

	handler := http.NewServeMux()
	handler.HandleFunc("/api/v2/templateversions/", func(w http.ResponseWriter, r *http.Request) {
		if strings.Contains(r.URL.RawQuery, "follow") {
			conn, err := websocket.Accept(w, r, nil)
			if err != nil {
				http.Error(w, err.Error(), http.StatusInternalServerError)
				return
			}
			ctx := r.Context()
			_ = wsjson.Write(ctx, conn, codersdk.ProvisionerJobLog{
				ID:     int64(versionCallCount.Load()),
				Output: "log line",
			})
			_ = conn.Close(websocket.StatusNormalClosure, "done")
			return
		}
		// Report the job as running on the first poll and succeeded afterwards.
		count := versionCallCount.Add(1)
		status := codersdk.ProvisionerJobRunning
		if count >= 2 {
			status = codersdk.ProvisionerJobSucceeded
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(codersdk.TemplateVersion{
			ID: versionID,
			Job: codersdk.ProvisionerJob{
				Status: status,
			},
		})
	})

	srv := httptest.NewServer(handler)
	t.Cleanup(srv.Close)
	srvURL, err := url.Parse(srv.URL)
	require.NoError(t, err)
	client := codersdk.New(srvURL)

	version := &codersdk.TemplateVersion{ID: versionID}
	logs, err := waitForJob(context.Background(), client, version)
	require.NoError(t, err)
	require.Len(t, logs, 2)
}