From ef1bb6f4ceae8c8c0c3279ebd5a96ef62fcd23f2 Mon Sep 17 00:00:00 2001
From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com>
Date: Tue, 12 Aug 2025 21:20:59 +0000
Subject: [PATCH 1/5] chore: add support for immortal streams to CLI and agent

---
 agent/agent.go                     |  50 ++++-
 cli/exp.go                         |   1 +
 cli/immortalstreams.go             | 188 ++++++++++++++++++
 cli/portforward.go                 |  17 ++
 cli/ssh.go                         | 304 ++++++++++++++++++++++++++++-
 coderd/coderd.go                   |   5 +
 coderd/workspaceagents.go          | 206 +++++++++++++++++++
 codersdk/workspaceagents.go        |  41 ++++
 codersdk/workspacesdk/agentconn.go |  69 +++++++
 9 files changed, 871 insertions(+), 10 deletions(-)
 create mode 100644 cli/immortalstreams.go

diff --git a/agent/agent.go b/agent/agent.go
index 31b48edd4dc83..4cefcfa9f8616 100644
--- a/agent/agent.go
+++ b/agent/agent.go
@@ -70,6 +70,47 @@ const (
 	EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
 )
 
+// agentImmortalDialer is a custom dialer for immortal streams that can
+// connect to the agent's own services via tailnet addresses.
+type agentImmortalDialer struct {
+	agent          *agent
+	standardDialer *net.Dialer
+}
+
+func (d *agentImmortalDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+	host, portStr, err := net.SplitHostPort(address)
+	if err != nil {
+		return nil, xerrors.Errorf("split host port %q: %w", address, err)
+	}
+
+	port, err := strconv.Atoi(portStr)
+	if err != nil {
+		return nil, xerrors.Errorf("parse port %q: %w", portStr, err)
+	}
+
+	// Check if this is a connection to one of the agent's own services.
+	isLocalhost := host == "localhost" || host == "127.0.0.1" || host == "::1"
+	isAgentPort := port == int(workspacesdk.AgentSSHPort) || port == int(workspacesdk.AgentHTTPAPIServerPort) ||
+		port == int(workspacesdk.AgentReconnectingPTYPort) || port == int(workspacesdk.AgentSpeedtestPort)
+
+	if isLocalhost && isAgentPort {
+		// Get the agent ID from the current manifest.
+		manifest := d.agent.manifest.Load()
+		if manifest == nil || manifest.AgentID == uuid.Nil {
+			// Fall back to standard dialing if no manifest is available yet.
+			return d.standardDialer.DialContext(ctx, network, address)
+		}
+
+		// Connect to the agent's own tailnet address instead of localhost.
+		agentAddr := tailnet.TailscaleServicePrefix.AddrFromUUID(manifest.AgentID)
+		agentAddress := net.JoinHostPort(agentAddr.String(), portStr)
+		return d.standardDialer.DialContext(ctx, network, agentAddress)
+	}
+
+	// For other addresses, use standard dialing.
+	return d.standardDialer.DialContext(ctx, network, address)
+}
+
 type Options struct {
 	Filesystem afero.Fs
 	LogDir     string
@@ -351,8 +392,13 @@ func (a *agent) init() {
 	a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
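The dialer above only rewrites loopback dials that target one of the agent's well-known service ports; everything else passes straight through to the standard dialer. Below is a minimal sketch of how a consumer sees it — the dialer interface and dialAgentSSH helper are illustrative names, not symbols from the patch, and the interface shape is an assumption based on how the hunk below wires agentImmortalDialer into immortalstreams.New:

package agentexample

import (
	"context"
	"net"
)

// dialer matches the method set the immortal streams manager needs from the
// dialer it is constructed with (assumed from the immortalstreams.New call
// in the hunk below).
type dialer interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// dialAgentSSH shows the intended effect: callers keep dialing loopback, and
// agentImmortalDialer transparently redirects "127.0.0.1:1" (AgentSSHPort)
// to the agent's own tailnet address once the manifest is known.
func dialAgentSSH(ctx context.Context, d dialer) (net.Conn, error) {
	return d.DialContext(ctx, "tcp", "127.0.0.1:1")
}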
- // Initialize immortal streams manager - a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), &net.Dialer{}) + // Initialize immortal streams manager with a custom dialer + // that can connect to the agent's own services + immortalDialer := &agentImmortalDialer{ + agent: a, + standardDialer: &net.Dialer{}, + } + a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), immortalDialer) a.reconnectingPTYServer = reconnectingpty.NewServer( a.logger.Named("reconnecting-pty"), diff --git a/cli/exp.go b/cli/exp.go index e20d1e28d5ffe..65dea10e064f1 100644 --- a/cli/exp.go +++ b/cli/exp.go @@ -17,6 +17,7 @@ func (r *RootCmd) expCmd() *serpent.Command { r.promptExample(), r.rptyCommand(), r.tasksCommand(), + r.immortalStreamCmd(), }, } return cmd diff --git a/cli/immortalstreams.go b/cli/immortalstreams.go new file mode 100644 index 0000000000000..7dc3e0300d7ab --- /dev/null +++ b/cli/immortalstreams.go @@ -0,0 +1,188 @@ +package cli + +import ( + "context" + "fmt" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/cli/cliui" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/serpent" +) + +// immortalStreamClient provides methods to interact with immortal streams API +// This uses the main codersdk.Client to make server-proxied requests to agents +type immortalStreamClient struct { + client *codersdk.Client + agentID uuid.UUID + logger slog.Logger +} + +// newImmortalStreamClient creates a new client for immortal streams +func newImmortalStreamClient(client *codersdk.Client, agentID uuid.UUID, logger slog.Logger) *immortalStreamClient { + return &immortalStreamClient{ + client: client, + agentID: agentID, + logger: logger, + } +} + +// createStream creates a new immortal stream +func (c *immortalStreamClient) createStream(ctx context.Context, port int) (*codersdk.ImmortalStream, error) { + stream, err := c.client.WorkspaceAgentCreateImmortalStream(ctx, c.agentID, codersdk.CreateImmortalStreamRequest{ + TCPPort: port, + }) + if err != nil { + return nil, err + } + return &stream, nil +} + +// listStreams lists all immortal streams +func (c *immortalStreamClient) listStreams(ctx context.Context) ([]codersdk.ImmortalStream, error) { + return c.client.WorkspaceAgentImmortalStreams(ctx, c.agentID) +} + +// deleteStream deletes an immortal stream +func (c *immortalStreamClient) deleteStream(ctx context.Context, streamID uuid.UUID) error { + return c.client.WorkspaceAgentDeleteImmortalStream(ctx, c.agentID, streamID) +} + +// CLI Commands + +func (r *RootCmd) immortalStreamCmd() *serpent.Command { + client := new(codersdk.Client) + cmd := &serpent.Command{ + Use: "immortal-stream", + Short: "Manage immortal streams in workspaces", + Long: "Immortal streams provide persistent TCP connections to workspace services that automatically reconnect when interrupted.", + Middleware: serpent.Chain( + r.InitClient(client), + ), + Handler: func(inv *serpent.Invocation) error { + return inv.Command.HelpHandler(inv) + }, + Children: []*serpent.Command{ + r.immortalStreamListCmd(), + r.immortalStreamDeleteCmd(), + }, + } + return cmd +} + +func (r *RootCmd) immortalStreamListCmd() *serpent.Command { + client := new(codersdk.Client) + cmd := &serpent.Command{ + Use: "list ", + Short: "List active immortal streams in a workspace", + Middleware: serpent.Chain( + serpent.RequireNArgs(1), + r.InitClient(client), + ), + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + workspaceName := 
inv.Args[0] + + workspace, workspaceAgent, _, err := getWorkspaceAndAgent(ctx, inv, client, false, workspaceName) + if err != nil { + return err + } + + if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart { + return xerrors.New("workspace must be running to list immortal streams") + } + + // Create immortal stream client + // Note: We don't need to dial the agent for management operations + // as these go through the server's proxy endpoints + streamClient := newImmortalStreamClient(client, workspaceAgent.ID, inv.Logger) + streams, err := streamClient.listStreams(ctx) + if err != nil { + return xerrors.Errorf("list immortal streams: %w", err) + } + + if len(streams) == 0 { + cliui.Info(inv.Stderr, "No active immortal streams found.") + return nil + } + + // Display the streams in a table + displayImmortalStreams(inv, streams) + return nil + }, + } + return cmd +} + +func (r *RootCmd) immortalStreamDeleteCmd() *serpent.Command { + client := new(codersdk.Client) + cmd := &serpent.Command{ + Use: "delete ", + Short: "Delete an active immortal stream", + Middleware: serpent.Chain( + serpent.RequireNArgs(2), + r.InitClient(client), + ), + Handler: func(inv *serpent.Invocation) error { + ctx := inv.Context() + workspaceName := inv.Args[0] + streamName := inv.Args[1] + + workspace, workspaceAgent, _, err := getWorkspaceAndAgent(ctx, inv, client, false, workspaceName) + if err != nil { + return err + } + + if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart { + return xerrors.New("workspace must be running to delete immortal streams") + } + + // Create immortal stream client + streamClient := newImmortalStreamClient(client, workspaceAgent.ID, inv.Logger) + streams, err := streamClient.listStreams(ctx) + if err != nil { + return xerrors.Errorf("list immortal streams: %w", err) + } + + var targetStream *codersdk.ImmortalStream + for _, stream := range streams { + if stream.Name == streamName { + targetStream = &stream + break + } + } + + if targetStream == nil { + return xerrors.Errorf("immortal stream %q not found", streamName) + } + + // Delete the stream + err = streamClient.deleteStream(ctx, targetStream.ID) + if err != nil { + return xerrors.Errorf("delete immortal stream: %w", err) + } + + cliui.Info(inv.Stderr, fmt.Sprintf("Deleted immortal stream %q (ID: %s)", streamName, targetStream.ID)) + return nil + }, + } + return cmd +} + +func displayImmortalStreams(inv *serpent.Invocation, streams []codersdk.ImmortalStream) { + _, _ = fmt.Fprintf(inv.Stderr, "Active Immortal Streams:\n\n") + _, _ = fmt.Fprintf(inv.Stderr, "%-20s %-6s %-20s %-20s\n", "NAME", "PORT", "CREATED", "LAST CONNECTED") + _, _ = fmt.Fprintf(inv.Stderr, "%-20s %-6s %-20s %-20s\n", "----", "----", "-------", "--------------") + + for _, stream := range streams { + createdTime := stream.CreatedAt.Format("2006-01-02 15:04:05") + lastConnTime := stream.LastConnectionAt.Format("2006-01-02 15:04:05") + + _, _ = fmt.Fprintf(inv.Stderr, "%-20s %-6d %-20s %-20s\n", + stream.Name, stream.TCPPort, createdTime, lastConnTime) + } + _, _ = fmt.Fprintf(inv.Stderr, "\n") +} diff --git a/cli/portforward.go b/cli/portforward.go index 1b055d9e4362e..d96a0d697d289 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -39,6 +39,10 @@ func (r *RootCmd) portForward() *serpent.Command { udpForwards []string // : disableAutostart bool appearanceConfig codersdk.AppearanceConfig + + // Immortal streams flags + immortal bool + immortalFallback bool = true // Default to true for port-forward ) client := 
new(codersdk.Client) cmd := &serpent.Command{ @@ -212,6 +216,19 @@ func (r *RootCmd) portForward() *serpent.Command { Description: "Forward UDP port(s) from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols.", Value: serpent.StringArrayOf(&udpForwards), }, + { + Flag: "immortal", + Description: "Use immortal streams for port forwarding connections, providing automatic reconnection when interrupted.", + Value: serpent.BoolOf(&immortal), + Hidden: true, + }, + { + Flag: "immortal-fallback", + Description: "If immortal streams are unavailable due to connection limits, fall back to regular TCP connection.", + Default: "true", + Value: serpent.BoolOf(&immortalFallback), + Hidden: true, + }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } diff --git a/cli/ssh.go b/cli/ssh.go index a2f0db7327bef..8ce2a0420f172 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -48,6 +48,7 @@ import ( "github.com/coder/quartz" "github.com/coder/retry" "github.com/coder/serpent" + "github.com/coder/websocket" ) const ( @@ -85,6 +86,10 @@ func (r *RootCmd) ssh() *serpent.Command { containerName string containerUser string + + // Immortal streams flags + immortal bool + immortalFallback bool // Default to false for SSH ) client := new(codersdk.Client) wsClient := workspacesdk.New(client) @@ -387,11 +392,83 @@ func (r *RootCmd) ssh() *serpent.Command { } if stdio { - rawSSH, err := conn.SSH(ctx) - if err != nil { - return xerrors.Errorf("connect SSH: %w", err) + var rawSSH net.Conn + var immortalStreamClient *immortalStreamClient + var streamID *uuid.UUID + + if immortal { + // Use immortal stream for SSH connection + immortalStreamClient = newImmortalStreamClient(client, workspaceAgent.ID, logger) + + // Create immortal stream to agent SSH port (1) + stream, err := immortalStreamClient.createStream(ctx, 1) + if err != nil { + logger.Error(ctx, "failed to create immortal stream for SSH", + slog.Error(err), + slog.F("agent_id", workspaceAgent.ID), + slog.F("target_port", 1), + slog.F("workspace", workspace.Name), + slog.F("agent_status", workspaceAgent.Status), + slog.F("immortal_fallback_enabled", immortalFallback)) + + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "too many immortal streams") || + strings.Contains(err.Error(), "The connection was refused")) + + if shouldFallback { + if strings.Contains(err.Error(), "too many immortal streams") { + logger.Warn(ctx, "too many immortal streams, falling back to regular SSH connection", + slog.F("max_streams", "32")) + } else { + logger.Warn(ctx, "Agent SSH service not available on port 1, falling back to regular SSH connection", + slog.F("reason", "connection_refused"), + slog.F("suggestion", "agent SSH server may not be running")) + } + logger.Info(ctx, "attempting fallback to regular SSH connection") + rawSSH, err = conn.SSH(ctx) + if err != nil { + logger.Error(ctx, "fallback SSH connection also failed", slog.Error(err)) + return xerrors.Errorf("connect SSH (fallback): %w", err) + } + logger.Info(ctx, "successfully connected via regular SSH fallback") + } else { + return xerrors.Errorf("create immortal stream for SSH: %w", err) + } + } else { + streamID = &stream.ID + logger.Info(ctx, "created immortal stream for SSH", slog.F("stream_name", stream.Name), slog.F("stream_id", stream.ID)) + + // Connect to the immortal stream via WebSocket + rawSSH, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) + if err != nil { + // Clean up the stream if connection fails 
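+					// Otherwise the failed stream would keep holding one of the
+					// agent's limited stream slots (creation fails with "too many
+					// immortal streams" once the cap is reached).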
+ _ = immortalStreamClient.deleteStream(ctx, stream.ID) + return xerrors.Errorf("connect to immortal stream: %w", err) + } + } + } else { + // Use regular SSH connection + rawSSH, err = conn.SSH(ctx) + if err != nil { + return xerrors.Errorf("connect SSH: %w", err) + } } - copier := newRawSSHCopier(logger, rawSSH, stdioReader, stdioWriter) + + var copier io.Closer + + if tcpConn, ok := rawSSH.(*gonet.TCPConn); ok { + // Use specialized raw SSH copier for regular TCP connections + rawCopier := newRawSSHCopier(logger, tcpConn, stdioReader, stdioWriter) + copier = rawCopier + // Start copying in the background for rawSSHCopier + go rawCopier.copy(&wg) + } else { + // Use generic copier for immortal stream connections + genericCopier := newGenericSSHCopier(logger, rawSSH, stdioReader, stdioWriter) + copier = genericCopier + // Start copying in the background for genericSSHCopier + go genericCopier.copy(&wg) + } + if err = stack.push("rawSSHCopier", copier); err != nil { return err } @@ -404,22 +481,108 @@ func (r *RootCmd) ssh() *serpent.Command { } } + // Set up cleanup for immortal stream + if immortalStreamClient != nil && streamID != nil { + defer func() { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) + } + }() + } + wg.Add(1) go func() { defer wg.Done() watchAndClose(ctx, func() error { + // Clean up immortal stream on termination + if immortalStreamClient != nil && streamID != nil { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err)) + } + } stack.close(xerrors.New("watchAndClose")) return nil }, logger, client, workspace, errCh) }() - copier.copy(&wg) + // The copying is already started in the background above + wg.Wait() return nil } - sshClient, err := conn.SSHClient(ctx) - if err != nil { - return xerrors.Errorf("ssh client: %w", err) + var sshClient *gossh.Client + var immortalStreamClient *immortalStreamClient + var streamID *uuid.UUID + + if immortal { + // Use immortal stream for SSH connection + immortalStreamClient = newImmortalStreamClient(client, workspaceAgent.ID, logger) + + // Create immortal stream to agent SSH port (1) + stream, err := immortalStreamClient.createStream(ctx, 1) + if err != nil { + logger.Error(ctx, "failed to create immortal stream for SSH (regular mode)", + slog.Error(err), + slog.F("agent_id", workspaceAgent.ID), + slog.F("target_port", 1), + slog.F("workspace", workspace.Name), + slog.F("agent_status", workspaceAgent.Status), + slog.F("immortal_fallback_enabled", immortalFallback)) + + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "too many immortal streams") || + strings.Contains(err.Error(), "The connection was refused")) + + if shouldFallback { + if strings.Contains(err.Error(), "too many immortal streams") { + logger.Warn(ctx, "too many immortal streams, falling back to regular SSH connection", + slog.F("max_streams", "32")) + } else { + logger.Warn(ctx, "Agent SSH service not available on port 1, falling back to regular SSH connection", + slog.F("reason", "connection_refused"), + slog.F("suggestion", "agent SSH server may not be running")) + } + logger.Info(ctx, "attempting fallback to regular SSH client") + sshClient, err = conn.SSHClient(ctx) + if err != nil { + logger.Error(ctx, "fallback SSH client creation also failed", slog.Error(err)) + 
return xerrors.Errorf("ssh client (fallback): %w", err) + } + logger.Info(ctx, "successfully created SSH client via regular fallback") + } else { + return xerrors.Errorf("create immortal stream for SSH: %w", err) + } + } else { + streamID = &stream.ID + logger.Info(ctx, "created immortal stream for SSH", slog.F("stream_name", stream.Name), slog.F("stream_id", stream.ID)) + + // Connect to the immortal stream and create SSH client + rawConn, err := connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) + if err != nil { + // Clean up the stream if connection fails + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + return xerrors.Errorf("connect to immortal stream: %w", err) + } + + // Create SSH client over the immortal stream connection + sshConn, chans, reqs, err := gossh.NewClientConn(rawConn, "localhost:22", &gossh.ClientConfig{ + HostKeyCallback: gossh.InsecureIgnoreHostKey(), + Timeout: 30 * time.Second, + }) + if err != nil { + rawConn.Close() + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + return xerrors.Errorf("ssh handshake over immortal stream: %w", err) + } + + sshClient = gossh.NewClient(sshConn, chans, reqs) + } + } else { + // Use regular SSH connection + sshClient, err = conn.SSHClient(ctx) + if err != nil { + return xerrors.Errorf("ssh client: %w", err) + } } + if err = stack.push("ssh client", sshClient); err != nil { return err } @@ -440,12 +603,27 @@ func (r *RootCmd) ssh() *serpent.Command { } } + // Set up cleanup for immortal stream in regular SSH mode + if immortalStreamClient != nil && streamID != nil { + defer func() { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) + } + }() + } + wg.Add(1) go func() { defer wg.Done() watchAndClose( ctx, func() error { + // Clean up immortal stream on termination + if immortalStreamClient != nil && streamID != nil { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err)) + } + } stack.close(xerrors.New("watchAndClose")) return nil }, @@ -728,11 +906,83 @@ func (r *RootCmd) ssh() *serpent.Command { Value: serpent.BoolOf(&forceNewTunnel), Hidden: true, }, + { + Flag: "immortal", + Description: "Use immortal streams for SSH connection, providing automatic reconnection when interrupted.", + Value: serpent.BoolOf(&immortal), + Hidden: true, + }, + { + Flag: "immortal-fallback", + Description: "If immortal streams are unavailable due to connection limits, fall back to regular TCP connection.", + Value: serpent.BoolOf(&immortalFallback), + Hidden: true, + }, sshDisableAutostartOption(serpent.BoolOf(&disableAutostart)), } return cmd } +// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket and returns a net.Conn +func connectToImmortalStreamWebSocket(ctx context.Context, agentConn *workspacesdk.AgentConn, streamID uuid.UUID, logger slog.Logger) (net.Conn, error) { + // Build the target address for the agent's HTTP API server + // We'll let the WebSocket dialer handle the actual connection through the agent + apiServerAddr := fmt.Sprintf("127.0.0.1:%d", workspacesdk.AgentHTTPAPIServerPort) + wsURL := fmt.Sprintf("ws://%s/api/v0/immortal-stream/%s", apiServerAddr, streamID) + + // Create WebSocket connection using the agent's tailnet connection + // The key is to use a custom dialer that routes through the agent 
connection + dialOptions := &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: &http.Transport{ + DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) { + // Route all connections through the agent connection + // The agent connection will handle routing to the correct internal address + + conn, err := agentConn.DialContext(dialCtx, network, addr) + if err != nil { + return nil, err + } + + return conn, nil + }, + }, + }, + // Disable compression for raw TCP data + CompressionMode: websocket.CompressionDisabled, + } + + // Connect to the WebSocket endpoint + conn, res, err := websocket.Dial(ctx, wsURL, dialOptions) + if err != nil { + if res != nil { + logger.Error(ctx, "WebSocket dial failed", + slog.F("stream_id", streamID), + slog.F("websocket_url", wsURL), + slog.F("status", res.StatusCode), + slog.F("status_text", res.Status), + slog.Error(err)) + } else { + logger.Error(ctx, "WebSocket dial failed (no response)", + slog.F("stream_id", streamID), + slog.F("websocket_url", wsURL), + slog.Error(err)) + } + return nil, xerrors.Errorf("dial immortal stream WebSocket: %w", err) + } + + logger.Info(ctx, "successfully connected to immortal stream WebSocket", + slog.F("stream_id", streamID)) + + // Convert WebSocket to net.Conn for SSH usage + // Use MessageBinary for raw TCP data transport + netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary) + + logger.Debug(ctx, "converted WebSocket to net.Conn for SSH usage") + + return netConn, nil +} + // findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it // corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or // vscode-coder--myusername--myworkspace). @@ -1276,6 +1526,44 @@ func newRawSSHCopier(logger slog.Logger, conn *gonet.TCPConn, r io.Reader, w io. 
return &rawSSHCopier{conn: conn, logger: logger, r: r, w: w, done: make(chan struct{})} } +// genericSSHCopier is similar to rawSSHCopier but works with any net.Conn (e.g., immortal streams) +type genericSSHCopier struct { + conn net.Conn + logger slog.Logger + r io.Reader + w io.Writer + done chan struct{} +} + +func newGenericSSHCopier(logger slog.Logger, conn net.Conn, r io.Reader, w io.Writer) *genericSSHCopier { + return &genericSSHCopier{conn: conn, logger: logger, r: r, w: w, done: make(chan struct{})} +} + +func (c *genericSSHCopier) copy(wg *sync.WaitGroup) { + defer close(c.done) + + // Copy stdin to connection + go func() { + defer c.conn.Close() + _, err := io.Copy(c.conn, c.r) + if err != nil { + c.logger.Debug(context.Background(), "error copying stdin to connection", slog.Error(err)) + } + }() + + // Copy connection to stdout + _, err := io.Copy(c.w, c.conn) + if err != nil { + c.logger.Debug(context.Background(), "error copying connection to stdout", slog.Error(err)) + } +} + +func (c *genericSSHCopier) Close() error { + c.conn.Close() + <-c.done + return nil +} + func (c *rawSSHCopier) copy(wg *sync.WaitGroup) { defer close(c.done) logCtx := context.Background() diff --git a/coderd/coderd.go b/coderd/coderd.go index 724952bde7bb9..1254a735eb45b 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1403,6 +1403,11 @@ func New(options *Options) *API { r.Get("/containers/watch", api.watchWorkspaceAgentContainers) r.Post("/containers/devcontainers/{devcontainer}/recreate", api.workspaceAgentRecreateDevcontainer) r.Get("/coordinate", api.workspaceAgentClientCoordinate) + r.Route("/immortal-streams", func(r chi.Router) { + r.Get("/", api.workspaceAgentImmortalStreams) + r.Post("/", api.workspaceAgentCreateImmortalStream) + r.Delete("/{immortalstream}", api.workspaceAgentDeleteImmortalStream) + }) // PTY is part of workspaceAppServer. 
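+			// The immortal-streams routes above only manage stream lifecycle
+			// (list/create/delete); stream data itself flows over a WebSocket
+			// served by the agent at /api/v0/immortal-stream/{id}.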
}) diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index f2ee1ac18e823..9da3c2fb85127 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -805,6 +805,212 @@ func (api *API) workspaceAgentListeningPorts(rw http.ResponseWriter, r *http.Req httpapi.Write(ctx, rw, http.StatusOK, portsResponse) } +// @Summary Get workspace agent immortal streams +// @ID get-workspace-agent-immortal-streams +// @Security CoderSessionToken +// @Produce json +// @Tags Agents +// @Param workspaceagent path string true "Workspace agent ID" format(uuid) +// @Success 200 {array} codersdk.ImmortalStream +// @Router /workspaceagents/{workspaceagent}/immortal-streams [get] +func (api *API) workspaceAgentImmortalStreams(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + workspaceAgent := httpmw.WorkspaceAgentParam(r) + + // Check agent connectivity with timeout + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + apiAgent, err := db2sdk.WorkspaceAgent( + api.DERPMap(), *api.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, api.AgentInactiveDisconnectTimeout, + api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), + ) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error reading workspace agent.", + Detail: err.Error(), + }) + return + } + if apiAgent.Status != codersdk.WorkspaceAgentConnected { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected), + }) + return + } + + agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error dialing workspace agent.", + Detail: err.Error(), + }) + return + } + defer release() + + streams, err := agentConn.ImmortalStreams(ctx) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching immortal streams.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, streams) +} + +// @Summary Create workspace agent immortal stream +// @ID create-workspace-agent-immortal-stream +// @Security CoderSessionToken +// @Accept json +// @Produce json +// @Tags Agents +// @Param workspaceagent path string true "Workspace agent ID" format(uuid) +// @Param request body codersdk.CreateImmortalStreamRequest true "Create immortal stream request" +// @Success 201 {object} codersdk.ImmortalStream +// @Router /workspaceagents/{workspaceagent}/immortal-streams [post] +func (api *API) workspaceAgentCreateImmortalStream(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + workspaceAgent := httpmw.WorkspaceAgentParam(r) + + var req codersdk.CreateImmortalStreamRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + + // Check agent connectivity with timeout + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + apiAgent, err := db2sdk.WorkspaceAgent( + api.DERPMap(), *api.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, api.AgentInactiveDisconnectTimeout, + api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), + ) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error reading workspace agent.", + Detail: err.Error(), + }) + return + } + if apiAgent.Status != 
codersdk.WorkspaceAgentConnected { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected), + }) + return + } + + agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error dialing workspace agent.", + Detail: err.Error(), + }) + return + } + defer release() + + stream, err := agentConn.CreateImmortalStream(ctx, req) + if err != nil { + // Check for specific error types from the agent + if strings.Contains(err.Error(), "too many immortal streams") { + httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{ + Message: "Too many immortal streams.", + }) + return + } + if strings.Contains(err.Error(), "connection was refused") { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: "The connection was refused.", + }) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error creating immortal stream.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusCreated, stream) +} + +// @Summary Delete workspace agent immortal stream +// @ID delete-workspace-agent-immortal-stream +// @Security CoderSessionToken +// @Tags Agents +// @Param workspaceagent path string true "Workspace agent ID" format(uuid) +// @Param immortalstream path string true "Immortal stream ID" format(uuid) +// @Success 200 {object} codersdk.Response +// @Router /workspaceagents/{workspaceagent}/immortal-streams/{immortalstream} [delete] +func (api *API) workspaceAgentDeleteImmortalStream(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + workspaceAgent := httpmw.WorkspaceAgentParam(r) + + streamIDStr := chi.URLParam(r, "immortalstream") + streamID, err := uuid.Parse(streamIDStr) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid immortal stream ID format.", + Detail: err.Error(), + }) + return + } + + // Check agent connectivity with timeout + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + apiAgent, err := db2sdk.WorkspaceAgent( + api.DERPMap(), *api.TailnetCoordinator.Load(), workspaceAgent, nil, nil, nil, api.AgentInactiveDisconnectTimeout, + api.DeploymentValues.AgentFallbackTroubleshootingURL.String(), + ) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error reading workspace agent.", + Detail: err.Error(), + }) + return + } + if apiAgent.Status != codersdk.WorkspaceAgentConnected { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: fmt.Sprintf("Agent state is %q, it must be in the %q state.", apiAgent.Status, codersdk.WorkspaceAgentConnected), + }) + return + } + + agentConn, release, err := api.agentProvider.AgentConn(ctx, workspaceAgent.ID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error dialing workspace agent.", + Detail: err.Error(), + }) + return + } + defer release() + + err = agentConn.DeleteImmortalStream(ctx, streamID) + if err != nil { + if strings.Contains(err.Error(), "stream not found") { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: "Immortal stream not found.", + }) + return + } + httpapi.Write(ctx, rw, 
http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error deleting immortal stream.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, codersdk.Response{ + Message: "Immortal stream deleted successfully.", + }) +} + // @Summary Watch workspace agent for container updates. // @ID watch-workspace-agent-for-container-updates // @Security CoderSessionToken diff --git a/codersdk/workspaceagents.go b/codersdk/workspaceagents.go index 4f3faedb534fc..a0fdad857b412 100644 --- a/codersdk/workspaceagents.go +++ b/codersdk/workspaceagents.go @@ -393,6 +393,47 @@ func (c *Client) WorkspaceAgentListeningPorts(ctx context.Context, agentID uuid. return listeningPorts, json.NewDecoder(res.Body).Decode(&listeningPorts) } +// WorkspaceAgentImmortalStreams returns a list of immortal streams for the given agent. +func (c *Client) WorkspaceAgentImmortalStreams(ctx context.Context, agentID uuid.UUID) ([]ImmortalStream, error) { + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/workspaceagents/%s/immortal-streams", agentID), nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, ReadBodyAsError(res) + } + var streams []ImmortalStream + return streams, json.NewDecoder(res.Body).Decode(&streams) +} + +// WorkspaceAgentCreateImmortalStream creates a new immortal stream for the given agent. +func (c *Client) WorkspaceAgentCreateImmortalStream(ctx context.Context, agentID uuid.UUID, req CreateImmortalStreamRequest) (ImmortalStream, error) { + res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/v2/workspaceagents/%s/immortal-streams", agentID), req) + if err != nil { + return ImmortalStream{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return ImmortalStream{}, ReadBodyAsError(res) + } + var stream ImmortalStream + return stream, json.NewDecoder(res.Body).Decode(&stream) +} + +// WorkspaceAgentDeleteImmortalStream deletes an immortal stream for the given agent. +func (c *Client) WorkspaceAgentDeleteImmortalStream(ctx context.Context, agentID uuid.UUID, streamID uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/v2/workspaceagents/%s/immortal-streams/%s", agentID, streamID), nil) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return ReadBodyAsError(res) + } + return nil +} + // WorkspaceAgentDevcontainerStatus is the status of a devcontainer. type WorkspaceAgentDevcontainerStatus string diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index bb929c9ba2a04..36dd471712a3c 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -1,6 +1,7 @@ package workspacesdk import ( + "bytes" "context" "encoding/binary" "encoding/json" @@ -312,6 +313,74 @@ func (c *agentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgent return resp, json.NewDecoder(res.Body).Decode(&resp) } +// ImmortalStreams lists the immortal streams that are currently active in the workspace. 
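+// It issues a GET to the agent's HTTP API at /api/v0/immortal-stream over the
+// tailnet connection.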
+func (c *AgentConn) ImmortalStreams(ctx context.Context) ([]codersdk.ImmortalStream, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/immortal-stream", nil) + if err != nil { + return nil, xerrors.Errorf("do request: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, codersdk.ReadBodyAsError(res) + } + + var streams []codersdk.ImmortalStream + return streams, json.NewDecoder(res.Body).Decode(&streams) +} + +// CreateImmortalStream creates a new immortal stream to the specified port. +func (c *AgentConn) CreateImmortalStream(ctx context.Context, req codersdk.CreateImmortalStreamRequest) (codersdk.ImmortalStream, error) { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + // Note: We can't easily add logging here since AgentConn doesn't have a logger + // But we can add some debug info to the error messages + + reqBody, err := json.Marshal(req) + if err != nil { + return codersdk.ImmortalStream{}, xerrors.Errorf("marshal request: %w", err) + } + + res, err := c.apiRequest(ctx, http.MethodPost, "/api/v0/immortal-stream", bytes.NewReader(reqBody)) + if err != nil { + return codersdk.ImmortalStream{}, xerrors.Errorf("do request to agent /api/v0/immortal-stream: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusCreated { + bodyErr := codersdk.ReadBodyAsError(res) + return codersdk.ImmortalStream{}, xerrors.Errorf("agent responded with status %d: %w", res.StatusCode, bodyErr) + } + + var stream codersdk.ImmortalStream + err = json.NewDecoder(res.Body).Decode(&stream) + if err != nil { + return codersdk.ImmortalStream{}, xerrors.Errorf("decode response: %w", err) + } + return stream, nil +} + +// DeleteImmortalStream deletes an immortal stream by ID. +func (c *AgentConn) DeleteImmortalStream(ctx context.Context, streamID uuid.UUID) error { + ctx, span := tracing.StartSpan(ctx) + defer span.End() + + path := fmt.Sprintf("/api/v0/immortal-stream/%s", streamID) + res, err := c.apiRequest(ctx, http.MethodDelete, path, nil) + if err != nil { + return xerrors.Errorf("do request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != http.StatusNoContent { + return codersdk.ReadBodyAsError(res) + } + + return nil +} + // Netcheck returns a network check report from the workspace agent. 
func (c *agentConn) Netcheck(ctx context.Context) (healthsdk.AgentNetcheckReport, error) { ctx, span := tracing.StartSpan(ctx) From 8d3d98b0685609366097cef396e7e3992e5b5645 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 19 Aug 2025 18:16:42 +0000 Subject: [PATCH 2/5] WIP --- cli/ssh.go | 8 ++++---- coderd/workspaceagents.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/cli/ssh.go b/cli/ssh.go index 8ce2a0420f172..e299363481711 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -411,11 +411,11 @@ func (r *RootCmd) ssh() *serpent.Command { slog.F("agent_status", workspaceAgent.Status), slog.F("immortal_fallback_enabled", immortalFallback)) - shouldFallback := immortalFallback && (strings.Contains(err.Error(), "too many immortal streams") || + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "Too many immortal streams") || strings.Contains(err.Error(), "The connection was refused")) if shouldFallback { - if strings.Contains(err.Error(), "too many immortal streams") { + if strings.Contains(err.Error(), "Too many immortal streams") { logger.Warn(ctx, "too many immortal streams, falling back to regular SSH connection", slog.F("max_streams", "32")) } else { @@ -528,11 +528,11 @@ func (r *RootCmd) ssh() *serpent.Command { slog.F("agent_status", workspaceAgent.Status), slog.F("immortal_fallback_enabled", immortalFallback)) - shouldFallback := immortalFallback && (strings.Contains(err.Error(), "too many immortal streams") || + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "Too many immortal streams") || strings.Contains(err.Error(), "The connection was refused")) if shouldFallback { - if strings.Contains(err.Error(), "too many immortal streams") { + if strings.Contains(err.Error(), "Too many immortal streams") { logger.Warn(ctx, "too many immortal streams, falling back to regular SSH connection", slog.F("max_streams", "32")) } else { diff --git a/coderd/workspaceagents.go b/coderd/workspaceagents.go index 9da3c2fb85127..09d3cb421438f 100644 --- a/coderd/workspaceagents.go +++ b/coderd/workspaceagents.go @@ -915,7 +915,7 @@ func (api *API) workspaceAgentCreateImmortalStream(rw http.ResponseWriter, r *ht stream, err := agentConn.CreateImmortalStream(ctx, req) if err != nil { // Check for specific error types from the agent - if strings.Contains(err.Error(), "too many immortal streams") { + if strings.Contains(err.Error(), "Too many Immortal Streams") { httpapi.Write(ctx, rw, http.StatusServiceUnavailable, codersdk.Response{ Message: "Too many immortal streams.", }) From 630145c179e3b5471af8e2d06790c2f4f7ebfd76 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 19 Aug 2025 18:33:32 +0000 Subject: [PATCH 3/5] WIP --- cli/portforward.go | 100 +++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 92 insertions(+), 8 deletions(-) diff --git a/cli/portforward.go b/cli/portforward.go index d96a0d697d289..d29a17e8e5b6c 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -13,6 +13,7 @@ import ( "sync" "syscall" + "github.com/google/uuid" "golang.org/x/xerrors" "cdr.dev/slog" @@ -152,7 +153,7 @@ func (r *RootCmd) portForward() *serpent.Command { // first, opportunistically try to listen on IPv6 spec6 := spec spec6.listenHost = ipv6Loopback - l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger) + l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger, immortal, immortalFallback, client, 
workspaceAgent.ID) if err6 != nil { logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6)) } else { @@ -160,7 +161,7 @@ func (r *RootCmd) portForward() *serpent.Command { } spec.listenHost = ipv4Loopback } - l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger) + l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger, immortal, immortalFallback, client, workspaceAgent.ID) if err != nil { logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err)) return err @@ -242,6 +243,10 @@ func listenAndPortForward( wg *sync.WaitGroup, spec portForwardSpec, logger slog.Logger, + immortal bool, + immortalFallback bool, + client *codersdk.Client, + agentID uuid.UUID, ) (net.Listener, error) { logger = logger.With( slog.F("network", spec.network), @@ -281,17 +286,96 @@ func listenAndPortForward( go func(netConn net.Conn) { defer netConn.Close() - remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress) - if err != nil { - _, _ = fmt.Fprintf(inv.Stderr, - "Failed to dial '%s://%s' in workspace: %s\n", - spec.network, dialAddress, err) - return + + var remoteConn net.Conn + var immortalStreamClient *immortalStreamClient + var streamID *uuid.UUID + + // Only use immortal streams for TCP connections + if immortal && spec.network == "tcp" { + // Create immortal stream client + immortalStreamClient = newImmortalStreamClient(client, agentID, logger) + + // Create immortal stream to the target port + stream, err := immortalStreamClient.createStream(ctx, int(spec.dialPort)) + if err != nil { + logger.Error(ctx, "failed to create immortal stream for port forward", + slog.Error(err), + slog.F("agent_id", agentID), + slog.F("target_port", spec.dialPort), + slog.F("immortal_fallback_enabled", immortalFallback)) + + shouldFallback := immortalFallback && (strings.Contains(err.Error(), "Too many immortal streams") || + strings.Contains(err.Error(), "The connection was refused")) + + if shouldFallback { + if strings.Contains(err.Error(), "Too many immortal streams") { + logger.Warn(ctx, "too many immortal streams, falling back to regular port forward", + slog.F("max_streams", "32"), + slog.F("target_port", spec.dialPort)) + } else { + logger.Warn(ctx, "service not available, falling back to regular port forward", + slog.F("reason", "connection_refused"), + slog.F("target_port", spec.dialPort)) + } + logger.Debug(ctx, "attempting fallback to regular port forward") + remoteConn, err = conn.DialContext(ctx, spec.network, dialAddress) + if err != nil { + logger.Error(ctx, "fallback port forward also failed", slog.Error(err)) + _, _ = fmt.Fprintf(inv.Stderr, + "Failed to dial '%s://%s' in workspace: %s\n", + spec.network, dialAddress, err) + return + } + logger.Debug(ctx, "successfully connected via regular port forward fallback") + } else { + _, _ = fmt.Fprintf(inv.Stderr, + "Failed to create immortal stream for '%s://%s' in workspace: %s\n", + spec.network, dialAddress, err) + return + } + } else { + streamID = &stream.ID + logger.Debug(ctx, "created immortal stream for port forward", + slog.F("stream_name", stream.Name), + slog.F("stream_id", stream.ID), + slog.F("target_port", spec.dialPort)) + + // Connect to the immortal stream via WebSocket + remoteConn, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) + if err != nil { + // Clean up the stream if connection fails + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + _, _ = fmt.Fprintf(inv.Stderr, + "Failed to connect to immortal stream for '%s://%s' 
in workspace: %s\n", + spec.network, dialAddress, err) + return + } + } + } else { + // Use regular connection for UDP or when immortal is disabled + remoteConn, err = conn.DialContext(ctx, spec.network, dialAddress) + if err != nil { + _, _ = fmt.Fprintf(inv.Stderr, + "Failed to dial '%s://%s' in workspace: %s\n", + spec.network, dialAddress, err) + return + } } + defer remoteConn.Close() logger.Debug(ctx, "dialed remote", slog.F("remote_addr", netConn.RemoteAddr())) + // Set up cleanup for immortal stream + if immortalStreamClient != nil && streamID != nil { + defer func() { + if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { + logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) + } + }() + } + agentssh.Bicopy(ctx, netConn, remoteConn) logger.Debug(ctx, "connection closing", slog.F("remote_addr", netConn.RemoteAddr())) From c8aa4e9816c74c9f82cc2ded67b5fa6d8ac1541e Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Wed, 20 Aug 2025 11:50:52 +0000 Subject: [PATCH 4/5] WIP --- agent/agent.go | 97 ++++++++++++++++++++++++-------------- cli/ssh.go | 124 +++++++++++++++++++++++++------------------------ 2 files changed, 125 insertions(+), 96 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 4cefcfa9f8616..f8ad5eb73f1a9 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -70,44 +70,13 @@ const ( EnvProcOOMScore = "CODER_PROC_OOM_SCORE" ) -// agentImmortalDialer is a custom dialer for immortal streams that can -// connect to the agent's own services via tailnet addresses. +// agentImmortalDialer wraps the standard dialer for immortal streams. +// Agent services are available on both tailnet and localhost interfaces. type agentImmortalDialer struct { - agent *agent standardDialer *net.Dialer } func (d *agentImmortalDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - host, portStr, err := net.SplitHostPort(address) - if err != nil { - return nil, xerrors.Errorf("split host port %q: %w", address, err) - } - - port, err := strconv.Atoi(portStr) - if err != nil { - return nil, xerrors.Errorf("parse port %q: %w", portStr, err) - } - - // Check if this is a connection to one of the agent's own services - isLocalhost := host == "localhost" || host == "127.0.0.1" || host == "::1" - isAgentPort := port == int(workspacesdk.AgentSSHPort) || port == int(workspacesdk.AgentHTTPAPIServerPort) || - port == int(workspacesdk.AgentReconnectingPTYPort) || port == int(workspacesdk.AgentSpeedtestPort) - - if isLocalhost && isAgentPort { - // Get the agent ID from the current manifest - manifest := d.agent.manifest.Load() - if manifest == nil || manifest.AgentID == uuid.Nil { - // Fallback to standard dialing if no manifest available yet - return d.standardDialer.DialContext(ctx, network, address) - } - - // Connect to the agent's own tailnet address instead of localhost - agentAddr := tailnet.TailscaleServicePrefix.AddrFromUUID(manifest.AgentID) - agentAddress := net.JoinHostPort(agentAddr.String(), portStr) - return d.standardDialer.DialContext(ctx, network, agentAddress) - } - - // For other addresses, use standard dialing return d.standardDialer.DialContext(ctx, network, address) } @@ -392,10 +361,8 @@ func (a *agent) init() { a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...) 
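Patch 4 drops the tailnet-address rewriting from patch 1 (hunk above) and instead has the agent listen on localhost in addition to its tailnet interface (hunks below). A rough sketch of that dual-listener pattern — serveOnBothInterfaces, tailnetListen, and handler are placeholder names standing in for the real network.Listen and apiHandler, not symbols from the patch:

package agentexample

import (
	"net"
	"net/http"
)

// serveOnBothInterfaces serves one handler on both a tailnet listener and a
// plain loopback listener, mirroring what createTailnet does below so that
// immortal streams can reach agent services without any address rewriting.
func serveOnBothInterfaces(tailnetListen func(network, addr string) (net.Listener, error), handler http.Handler, port string) error {
	tl, err := tailnetListen("tcp", ":"+port)
	if err != nil {
		return err
	}
	ll, err := net.Listen("tcp", "127.0.0.1:"+port)
	if err != nil {
		_ = tl.Close()
		return err
	}
	srv := &http.Server{Handler: handler}
	go func() { _ = srv.Serve(tl) }()
	go func() { _ = srv.Serve(ll) }()
	return nil
}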
- // Initialize immortal streams manager with a custom dialer - // that can connect to the agent's own services + // Initialize immortal streams manager immortalDialer := &agentImmortalDialer{ - agent: a, standardDialer: &net.Dialer{}, } a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), immortalDialer) @@ -1531,6 +1498,7 @@ func (a *agent) createTailnet( } for _, port := range []int{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort} { + // Listen on tailnet interface for external connections sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(port)) if err != nil { return nil, xerrors.Errorf("listen on the ssh port (%v): %w", port, err) @@ -1546,6 +1514,25 @@ func (a *agent) createTailnet( }); err != nil { return nil, err } + + // Also listen on localhost for immortal streams (only for SSH port 1) + if port == workspacesdk.AgentSSHPort { + localhostListener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port)) + if err != nil { + return nil, xerrors.Errorf("listen on localhost ssh port (%v): %w", port, err) + } + // nolint:revive // We do want to run the deferred functions when createTailnet returns. + defer func() { + if err != nil { + _ = localhostListener.Close() + } + }() + if err = a.trackGoroutine(func() { + _ = a.sshServer.Serve(localhostListener) + }); err != nil { + return nil, err + } + } } reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentReconnectingPTYPort)) @@ -1616,6 +1603,7 @@ func (a *agent) createTailnet( return nil, err } + // Listen on tailnet interface for external connections apiListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort)) if err != nil { return nil, xerrors.Errorf("api listener: %w", err) @@ -1652,6 +1640,43 @@ func (a *agent) createTailnet( return nil, err } + // Also listen on localhost for immortal streams WebSocket connections + localhostAPIListener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort)) + if err != nil { + return nil, xerrors.Errorf("localhost api listener: %w", err) + } + defer func() { + if err != nil { + _ = localhostAPIListener.Close() + } + }() + if err = a.trackGoroutine(func() { + defer localhostAPIListener.Close() + apiHandler := a.apiHandler() + server := &http.Server{ + BaseContext: func(net.Listener) context.Context { return ctx }, + Handler: apiHandler, + ReadTimeout: 20 * time.Second, + ReadHeaderTimeout: 20 * time.Second, + WriteTimeout: 20 * time.Second, + ErrorLog: slog.Stdlib(ctx, a.logger.Named("http_api_server_localhost"), slog.LevelInfo), + } + go func() { + select { + case <-ctx.Done(): + case <-a.hardCtx.Done(): + } + _ = server.Close() + }() + + apiServErr := server.Serve(localhostAPIListener) + if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") { + a.logger.Critical(ctx, "serve localhost HTTP API server", slog.Error(apiServErr)) + } + }); err != nil { + return nil, err + } + return network, nil } diff --git a/cli/ssh.go b/cli/ssh.go index e299363481711..478473d294ee3 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -440,8 +440,10 @@ func (r *RootCmd) ssh() *serpent.Command { // Connect to the immortal stream via WebSocket rawSSH, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) if err != nil { - // Clean up the stream if connection fails - _ = immortalStreamClient.deleteStream(ctx, stream.ID) + // Only clean up the stream if it's a 
permanent failure + if !isNetworkError(err) { + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + } return xerrors.Errorf("connect to immortal stream: %w", err) } } @@ -481,12 +483,17 @@ func (r *RootCmd) ssh() *serpent.Command { } } - // Set up cleanup for immortal stream + // Set up signal-based cleanup for immortal stream + // Only delete on explicit user termination (SIGINT, SIGTERM), not network errors if immortalStreamClient != nil && streamID != nil { - defer func() { - if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { - logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) - } + // Create a signal-only context for cleanup + signalCtx, signalStop := inv.SignalNotifyContext(context.Background(), StopSignals...) + defer signalStop() + + go func() { + <-signalCtx.Done() + // User sent termination signal - clean up the stream + _ = immortalStreamClient.deleteStream(context.Background(), *streamID) }() } @@ -494,12 +501,7 @@ func (r *RootCmd) ssh() *serpent.Command { go func() { defer wg.Done() watchAndClose(ctx, func() error { - // Clean up immortal stream on termination - if immortalStreamClient != nil && streamID != nil { - if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { - logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err)) - } - } + // Don't delete immortal stream here - let signal handler do it stack.close(xerrors.New("watchAndClose")) return nil }, logger, client, workspace, errCh) @@ -557,8 +559,10 @@ func (r *RootCmd) ssh() *serpent.Command { // Connect to the immortal stream and create SSH client rawConn, err := connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) if err != nil { - // Clean up the stream if connection fails - _ = immortalStreamClient.deleteStream(ctx, stream.ID) + // Only clean up the stream if it's a permanent failure + if !isNetworkError(err) { + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + } return xerrors.Errorf("connect to immortal stream: %w", err) } @@ -569,7 +573,10 @@ func (r *RootCmd) ssh() *serpent.Command { }) if err != nil { rawConn.Close() - _ = immortalStreamClient.deleteStream(ctx, stream.ID) + // Only clean up the stream if it's a permanent failure + if !isNetworkError(err) { + _ = immortalStreamClient.deleteStream(ctx, stream.ID) + } return xerrors.Errorf("ssh handshake over immortal stream: %w", err) } @@ -603,12 +610,17 @@ func (r *RootCmd) ssh() *serpent.Command { } } - // Set up cleanup for immortal stream in regular SSH mode + // Set up signal-based cleanup for immortal stream + // Only delete on explicit user termination (SIGINT, SIGTERM), not network errors if immortalStreamClient != nil && streamID != nil { - defer func() { - if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { - logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err)) - } + // Create a signal-only context for cleanup + signalCtx, signalStop := inv.SignalNotifyContext(context.Background(), StopSignals...) 
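+			// Assuming SignalNotifyContext mirrors signal.NotifyContext, the
+			// deferred signalStop also cancels signalCtx on return, so this
+			// cleanup runs on normal exit as well as on an explicit signal.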
+ defer signalStop() + + go func() { + <-signalCtx.Done() + // User sent termination signal - clean up the stream + _ = immortalStreamClient.deleteStream(context.Background(), *streamID) }() } @@ -618,12 +630,7 @@ func (r *RootCmd) ssh() *serpent.Command { watchAndClose( ctx, func() error { - // Clean up immortal stream on termination - if immortalStreamClient != nil && streamID != nil { - if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil { - logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err)) - } - } + // Don't delete immortal stream here - let signal handler do it stack.close(xerrors.New("watchAndClose")) return nil }, @@ -923,66 +930,63 @@ func (r *RootCmd) ssh() *serpent.Command { return cmd } -// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket and returns a net.Conn +// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket +// The immortal stream infrastructure handles reconnection automatically func connectToImmortalStreamWebSocket(ctx context.Context, agentConn *workspacesdk.AgentConn, streamID uuid.UUID, logger slog.Logger) (net.Conn, error) { // Build the target address for the agent's HTTP API server - // We'll let the WebSocket dialer handle the actual connection through the agent apiServerAddr := fmt.Sprintf("127.0.0.1:%d", workspacesdk.AgentHTTPAPIServerPort) wsURL := fmt.Sprintf("ws://%s/api/v0/immortal-stream/%s", apiServerAddr, streamID) // Create WebSocket connection using the agent's tailnet connection - // The key is to use a custom dialer that routes through the agent connection dialOptions := &websocket.DialOptions{ HTTPClient: &http.Client{ Transport: &http.Transport{ DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) { - // Route all connections through the agent connection - // The agent connection will handle routing to the correct internal address - - conn, err := agentConn.DialContext(dialCtx, network, addr) - if err != nil { - return nil, err - } - - return conn, nil + return agentConn.DialContext(dialCtx, network, addr) }, }, }, - // Disable compression for raw TCP data CompressionMode: websocket.CompressionDisabled, } // Connect to the WebSocket endpoint - conn, res, err := websocket.Dial(ctx, wsURL, dialOptions) + conn, _, err := websocket.Dial(ctx, wsURL, dialOptions) if err != nil { - if res != nil { - logger.Error(ctx, "WebSocket dial failed", - slog.F("stream_id", streamID), - slog.F("websocket_url", wsURL), - slog.F("status", res.StatusCode), - slog.F("status_text", res.Status), - slog.Error(err)) - } else { - logger.Error(ctx, "WebSocket dial failed (no response)", - slog.F("stream_id", streamID), - slog.F("websocket_url", wsURL), - slog.Error(err)) - } return nil, xerrors.Errorf("dial immortal stream WebSocket: %w", err) } - logger.Info(ctx, "successfully connected to immortal stream WebSocket", - slog.F("stream_id", streamID)) - // Convert WebSocket to net.Conn for SSH usage - // Use MessageBinary for raw TCP data transport + // The immortal stream's BackedPipe handles reconnection automatically netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary) - logger.Debug(ctx, "converted WebSocket to net.Conn for SSH usage") - return netConn, nil } +// isNetworkError checks if an error is a temporary network error +func isNetworkError(err error) bool { + if err == nil { + return false + } + + errStr := err.Error() + networkErrors := []string{ + "connection refused", + 
"network is unreachable", + "connection reset", + "broken pipe", + "timeout", + "no route to host", + } + + for _, netErr := range networkErrors { + if strings.Contains(errStr, netErr) { + return true + } + } + + return false +} + // findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it // corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or // vscode-coder--myusername--myworkspace). From 5e5d5b5850d4a6762e456aa3f05841eba738b79d Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Mon, 25 Aug 2025 20:31:02 +0000 Subject: [PATCH 5/5] WIP --- cli/immortal_reconnecting_conn.go | 558 ++++++++++++++++++ cli/immortalstreams.go | 4 +- cli/ssh.go | 69 ++- codersdk/workspacesdk/agentconn.go | 11 +- .../agentconnmock/agentconnmock.go | 44 ++ 5 files changed, 646 insertions(+), 40 deletions(-) create mode 100644 cli/immortal_reconnecting_conn.go diff --git a/cli/immortal_reconnecting_conn.go b/cli/immortal_reconnecting_conn.go new file mode 100644 index 0000000000000..bfb8d474e1fce --- /dev/null +++ b/cli/immortal_reconnecting_conn.go @@ -0,0 +1,558 @@ +package cli + +import ( + "context" + "fmt" + "net" + "net/http" + "strconv" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/sync/singleflight" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" + "github.com/coder/websocket" +) + +// immortalReconnectingConn is a net.Conn that talks to an agent Immortal Stream +// endpoint and transparently reconnects on failures. It preserves read +// sequence state via the X-Coder-Immortal-Stream-Sequence-Num header so the +// server can replay any missed bytes to the client. Writes will block across +// reconnects and resume once a new connection is established. +// +// Note: Without an explicit server-to-client acknowledgement of how many bytes +// of client->server data were consumed, we avoid attempting to replay writes +// from the client. Instead, Write blocks until a new connection is established +// and then continues writing new data. This preserves the SSH session transport +// and prevents premature termination. +type immortalReconnectingConn struct { + ctx context.Context + cancel context.CancelFunc + + agentConn workspacesdk.AgentConn + streamID uuid.UUID + logger slog.Logger + + mu sync.Mutex + ws *websocket.Conn + nc net.Conn + closed bool + readerSN uint64 // total bytes read by this client + + // cancel the per-connection keepalive loop + keepaliveCancel context.CancelFunc + + // Deduplicate concurrent reconnect attempts + sf singleflight.Group + + // start the background reconnect supervisor only once + bgOnce sync.Once + + // Optional: called when the server indicates the stream ID is invalid + // and a new stream should be created. Returns a replacement stream ID. + refreshStreamID func(context.Context) (uuid.UUID, error) +} + +// newImmortalReconnectingConn constructs a reconnecting connection and dials +// the initial websocket. Subsequent reads/writes will reconnect on demand. 
+func newImmortalReconnectingConn(parent context.Context, agentConn workspacesdk.AgentConn, streamID uuid.UUID, logger slog.Logger, refresh func(context.Context) (uuid.UUID, error)) (net.Conn, error) { + ctx, cancel := context.WithCancel(parent) + + // Add connection ID for better logging + connID := uuid.New() + logger = logger.With(slog.F("conn_id", connID), slog.F("stream_id", streamID)) + + c := &immortalReconnectingConn{ + ctx: ctx, + cancel: cancel, + agentConn: agentConn, + streamID: streamID, + logger: logger, + refreshStreamID: refresh, + } + + c.logger.Debug(context.Background(), "creating new immortal reconnecting connection") + + if err := c.ensureConnected(); err != nil { + cancel() + return nil, xerrors.Errorf("initial connection failed: %w", err) + } + + c.logger.Debug(context.Background(), "immortal reconnecting connection created successfully") + // Ensure we always have an out-of-band retry loop so that reconnects + // continue even when no reader/writer is active. + c.startReconnectSupervisor() + return c, nil +} + +func (c *immortalReconnectingConn) Read(p []byte) (int, error) { + for { + c.mu.Lock() + nc := c.nc + closed := c.closed + c.mu.Unlock() + if closed { + c.logger.Debug(context.Background(), "read called on closed connection") + return 0, net.ErrClosed + } + + if nc == nil { + c.logger.Debug(context.Background(), "read called on nil connection, attempting reconnect") + if err := c.reconnect(); err != nil { + if c.ctx.Err() != nil { + c.logger.Debug(context.Background(), "read reconnect failed due to context cancellation", slog.Error(err)) + return 0, c.ctx.Err() + } + c.logger.Error(context.Background(), "read reconnect failed, will retry", slog.Error(err)) + // Brief backoff to avoid hot loop + select { + case <-time.After(200 * time.Millisecond): + case <-c.ctx.Done(): + return 0, c.ctx.Err() + } + continue + } + continue + } + + n, err := nc.Read(p) + if n > 0 { + c.mu.Lock() + c.readerSN += uint64(n) + c.mu.Unlock() + c.logger.Debug(context.Background(), "read successful", slog.F("bytes", n), slog.F("total_read", c.readerSN)) + return n, nil + } + if err == nil { + // zero bytes without error, try again + continue + } + + // Read error: trigger reconnect loop and retry + c.logger.Debug(context.Background(), "immortal read error, reconnecting", slog.Error(err)) + _ = c.reconnect() + // Loop to retry read on new connection + } +} + +func (c *immortalReconnectingConn) Write(p []byte) (int, error) { + writtenTotal := 0 + for writtenTotal < len(p) { + c.mu.Lock() + nc := c.nc + closed := c.closed + c.mu.Unlock() + if closed { + c.logger.Debug(context.Background(), "write called on closed connection") + return writtenTotal, net.ErrClosed + } + if nc == nil { + c.logger.Debug(context.Background(), "write called on nil connection, attempting reconnect") + if err := c.reconnect(); err != nil { + if c.ctx.Err() != nil { + c.logger.Debug(context.Background(), "write reconnect failed due to context cancellation", slog.Error(err)) + return writtenTotal, c.ctx.Err() + } + c.logger.Error(context.Background(), "write reconnect failed, will retry", slog.Error(err)) + // Backoff before reattempting + select { + case <-time.After(200 * time.Millisecond): + case <-c.ctx.Done(): + return writtenTotal, c.ctx.Err() + } + continue + } + continue + } + + n, err := nc.Write(p[writtenTotal:]) + if n > 0 { + writtenTotal += n + c.logger.Debug(context.Background(), "write partial success", slog.F("bytes", n), slog.F("total_written", writtenTotal), slog.F("remaining", 
len(p)-writtenTotal)) + } + if err == nil { + continue + } + // Write error: reconnect and retry remaining bytes + c.logger.Debug(context.Background(), "immortal write error, reconnecting", slog.Error(err)) + if rerr := c.reconnect(); rerr != nil { + if c.ctx.Err() != nil { + return writtenTotal, c.ctx.Err() + } + c.logger.Error(context.Background(), "write reconnect failed, will retry", slog.Error(rerr)) + // Brief backoff then try again + select { + case <-time.After(200 * time.Millisecond): + case <-c.ctx.Done(): + return writtenTotal, c.ctx.Err() + } + } + } + c.logger.Debug(context.Background(), "write completed successfully", slog.F("total_bytes", writtenTotal)) + return writtenTotal, nil +} + +func (c *immortalReconnectingConn) Close() error { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil + } + c.closed = true + ws := c.ws + nc := c.nc + kaCancel := c.keepaliveCancel + c.ws = nil + c.nc = nil + c.keepaliveCancel = nil + c.mu.Unlock() + + c.logger.Debug(context.Background(), "closing immortal reconnecting connection") + + c.cancel() + if kaCancel != nil { + kaCancel() + } + + var firstErr error + if nc != nil { + if err := nc.Close(); err != nil && firstErr == nil { + firstErr = err + } + } + if ws != nil { + if err := ws.Close(websocket.StatusNormalClosure, ""); err != nil && firstErr == nil { + firstErr = err + } + } + + if firstErr != nil { + c.logger.Error(context.Background(), "error during connection close", slog.Error(firstErr)) + } else { + c.logger.Debug(context.Background(), "immortal reconnecting connection closed successfully") + } + + return firstErr +} + +func (c *immortalReconnectingConn) LocalAddr() net.Addr { + c.mu.Lock() + defer c.mu.Unlock() + if c.nc != nil { + return c.nc.LocalAddr() + } + // best-effort zero addr + return nil +} + +func (c *immortalReconnectingConn) RemoteAddr() net.Addr { + c.mu.Lock() + defer c.mu.Unlock() + if c.nc != nil { + return c.nc.RemoteAddr() + } + return nil +} + +func (c *immortalReconnectingConn) SetDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.nc != nil { + return c.nc.SetDeadline(t) + } + return nil +} + +func (c *immortalReconnectingConn) SetReadDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.nc != nil { + return c.nc.SetReadDeadline(t) + } + return nil +} + +func (c *immortalReconnectingConn) SetWriteDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.nc != nil { + return c.nc.SetWriteDeadline(t) + } + return nil +} + +// ensureConnected dials the websocket if not currently connected. +func (c *immortalReconnectingConn) ensureConnected() error { + c.mu.Lock() + already := c.nc != nil + c.mu.Unlock() + if already { + return nil + } + _, err, _ := c.sf.Do("reconnect", func() (any, error) { + return nil, c.connectOnce() + }) + return err +} + +// reconnect forces a reconnect regardless of current state. 
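+// Concurrent failures from Read, Write, and the keepalive loop all funnel
+// into the same singleflight key ("reconnect"), so at most one dial is in
+// flight at any time and every blocked caller shares its outcome.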
+func (c *immortalReconnectingConn) reconnect() error { + c.logger.Debug(context.Background(), "starting reconnection process") + + _, err, _ := c.sf.Do("reconnect", func() (any, error) { + // Close any existing connection outside of lock to unblock reader/writer + c.mu.Lock() + // stop previous keepalive loop if any + if c.keepaliveCancel != nil { + c.logger.Debug(context.Background(), "canceling previous keepalive loop") + c.keepaliveCancel() + c.keepaliveCancel = nil + } + ws := c.ws + nc := c.nc + c.ws = nil + c.nc = nil + c.mu.Unlock() + + if nc != nil { + c.logger.Debug(context.Background(), "closing previous net.Conn") + _ = nc.Close() + } + if ws != nil { + c.logger.Debug(context.Background(), "closing previous websocket") + _ = ws.Close(websocket.StatusNormalClosure, "reconnect") + } + + c.logger.Debug(context.Background(), "attempting new connection") + return nil, c.connectOnce() + }) + + if err != nil { + c.logger.Error(context.Background(), "reconnection failed", slog.Error(err)) + } else { + c.logger.Debug(context.Background(), "reconnection completed successfully") + } + + // Kick the supervisor so it continues retrying if we failed here. + if err != nil { + c.startReconnectSupervisor() + } + return err +} + +func (c *immortalReconnectingConn) connectOnce() error { + if c.ctx.Err() != nil { + c.logger.Debug(context.Background(), "connectOnce called with canceled context", slog.Error(c.ctx.Err())) + return c.ctx.Err() + } + + // Build the target address for the agent's HTTP API server + apiServerAddr := fmt.Sprintf("127.0.0.1:%d", workspacesdk.AgentHTTPAPIServerPort) + wsURL := fmt.Sprintf("ws://%s/api/v0/immortal-stream/%s", apiServerAddr, c.streamID) + + // Include current reader sequence so the server can replay any missed bytes + c.mu.Lock() + readerSeq := c.readerSN + c.mu.Unlock() + + c.logger.Debug(context.Background(), "dialing websocket", + slog.F("url", wsURL), + slog.F("reader_seq", readerSeq)) + + dialOptions := &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: &http.Transport{ + DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) { + c.logger.Debug(context.Background(), "dialing network connection", slog.F("network", network), slog.F("addr", addr)) + return c.agentConn.DialContext(dialCtx, network, addr) + }, + }, + }, + HTTPHeader: http.Header{ + codersdk.HeaderImmortalStreamSequenceNum: []string{strconv.FormatUint(readerSeq, 10)}, + }, + CompressionMode: websocket.CompressionDisabled, + } + + // Use a per-attempt dial timeout to avoid indefinite hangs on half-open + // connections or blackholed networks during reconnect attempts. + dialCtx, cancel := context.WithTimeout(c.ctx, 10*time.Second) + defer cancel() + + ws, resp, err := websocket.Dial(dialCtx, wsURL, dialOptions) + if err != nil { + // If we received an HTTP response, inspect status codes that indicate + // the stream ID is no longer valid and attempt to refresh it. + if resp != nil { + status := resp.StatusCode + _ = resp.Body.Close() + if (status == http.StatusNotFound || status == http.StatusGone || status == http.StatusBadRequest) && c.refreshStreamID != nil { + c.logger.Warn(context.Background(), "immortal stream appears invalid; attempting to refresh stream id", slog.F("status", status)) + // Try to obtain a new stream ID and retry once immediately. + newID, rerr := c.refreshStreamID(c.ctx) + if rerr == nil { + c.mu.Lock() + c.streamID = newID + c.mu.Unlock() + // Rebuild URL with new ID and redial within same attempt context. 
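+ // dialOptions is reused as-is, so the sequence-number header still
+ // carries the byte count read from the old stream; this assumes the
+ // agent treats the header as a replay hint that it can safely ignore
+ // for a freshly created stream.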
+ wsURL = fmt.Sprintf("ws://%s/api/v0/immortal-stream/%s", apiServerAddr, newID) + c.logger.Info(context.Background(), "retrying websocket dial with refreshed stream id", slog.F("url", wsURL)) + ws2, resp2, err2 := websocket.Dial(dialCtx, wsURL, dialOptions) + if err2 == nil { + ws = ws2 + // Ensure any intermediate resp2 is closed by NetConn when closed; nothing to do here. + goto DIAL_SUCCESS + } + if resp2 != nil && resp2.Body != nil { + _ = resp2.Body.Close() + } + c.logger.Error(context.Background(), "websocket dial failed after refresh", slog.Error(err2)) + return xerrors.Errorf("dial immortal stream websocket (after refresh): %w", err2) + } + c.logger.Error(context.Background(), "failed to refresh immortal stream id", slog.Error(rerr)) + } + } + c.logger.Error(context.Background(), "websocket dial failed", slog.Error(err), slog.F("url", wsURL)) + return xerrors.Errorf("dial immortal stream websocket: %w", err) + } +DIAL_SUCCESS: + c.logger.Debug(context.Background(), "websocket dial successful") + + // Convert WebSocket to net.Conn for binary transport + // Tie lifecycle to our context so reads/writes unblock on shutdown + nc := websocket.NetConn(c.ctx, ws, websocket.MessageBinary) + + // swap in new connection and start keepalive + c.mu.Lock() + // stop previous keepalive loop if any (defensive if connectOnce is called directly) + if c.keepaliveCancel != nil { + c.logger.Debug(context.Background(), "canceling existing keepalive loop during connection swap") + c.keepaliveCancel() + c.keepaliveCancel = nil + } + c.ws = ws + c.nc = nc + kaCtx, kaCancel := context.WithCancel(c.ctx) + c.keepaliveCancel = kaCancel + c.mu.Unlock() + + // start ping keepalive + go c.keepaliveLoop(kaCtx, ws) + + c.logger.Debug(context.Background(), "connected to immortal stream", slog.F("stream_id", c.streamID)) + return nil +} + +// keepaliveLoop periodically pings the websocket to detect half-open connections. +// On ping failure it triggers a reconnect. +func (c *immortalReconnectingConn) keepaliveLoop(ctx context.Context, ws *websocket.Conn) { + c.logger.Debug(context.Background(), "starting keepalive loop") + + t := time.NewTicker(1 * time.Second) + defer t.Stop() + + pingCount := 0 + for { + select { + case <-ctx.Done(): + c.logger.Debug(context.Background(), "keepalive loop context canceled", slog.Error(ctx.Err())) + return + case <-t.C: + pingCount++ + pctx, cancel := context.WithTimeout(ctx, time.Second) + err := ws.Ping(pctx) + cancel() + if err != nil { + c.logger.Debug(context.Background(), "immortal ping failed, reconnecting", + slog.Error(err), + slog.F("ping_count", pingCount)) + // Best effort: trigger reconnect to replace dead socket. + // Don't return - continue monitoring for future failures + _ = c.reconnect() + // Continue the loop to monitor the new connection + } else if pingCount%10 == 0 { // Log every 10th successful ping to avoid spam + c.logger.Debug(context.Background(), "keepalive ping successful", slog.F("ping_count", pingCount)) + } + } + } +} + +// startReconnectSupervisor launches a background loop that ensures reconnect attempts +// continue indefinitely while the context is alive, even if there are no active +// Read/Write calls to trigger reconnects. +func (c *immortalReconnectingConn) startReconnectSupervisor() { + c.bgOnce.Do(func() { + go func() { + // Basic capped exponential backoff. 
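+ // Schedule (absent a successful connect): 200ms, 400ms, 800ms, 1.6s,
+ // 3.2s, then capped at 5s per attempt; the delay resets to 200ms after
+ // any successful connection.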
+ backoff := 200 * time.Millisecond + const maxBackoff = 5 * time.Second + failureCount := 0 + for { + select { + case <-c.ctx.Done(): + return + default: + } + + c.mu.Lock() + hasConn := c.nc != nil + ws := c.ws + c.mu.Unlock() + + if hasConn { + // Reset backoff when we have a healthy connection. + backoff = 200 * time.Millisecond + // Actively ping in case the keepalive loop has stopped for any reason. + if ws != nil { + pctx, cancel := context.WithTimeout(c.ctx, 2*time.Second) + if err := ws.Ping(pctx); err != nil { + cancel() + // Escalate visibility so users see continued retries. + c.logger.Error(context.Background(), "supervisor ping failed, forcing reconnect", slog.Error(err)) + _ = c.reconnect() + } else { + cancel() + } + } + // Poll sparsely to detect transitions without busy looping. + select { + case <-time.After(1 * time.Second): + case <-c.ctx.Done(): + return + } + continue + } + + // No connection: attempt a reconnect. + if err := c.ensureConnected(); err != nil { + failureCount++ + // Log as error so it's visible that we're still retrying. + c.logger.Error(context.Background(), "background reconnect attempt failed", slog.Error(err), slog.F("attempt", failureCount), slog.F("backoff", backoff.String())) + // Backoff and retry until success or context cancel. + select { + case <-time.After(backoff): + case <-c.ctx.Done(): + return + } + if backoff < maxBackoff { + backoff *= 2 + if backoff > maxBackoff { + backoff = maxBackoff + } + } + continue + } + + // Success: a keepalive loop is created by connectOnce; loop will + // continue monitoring in case the connection drops again. + failureCount = 0 + backoff = 200 * time.Millisecond + } + }() + }) +} diff --git a/cli/immortalstreams.go b/cli/immortalstreams.go index 7dc3e0300d7ab..709dd5d5cce21 100644 --- a/cli/immortalstreams.go +++ b/cli/immortalstreams.go @@ -86,7 +86,7 @@ func (r *RootCmd) immortalStreamListCmd() *serpent.Command { ctx := inv.Context() workspaceName := inv.Args[0] - workspace, workspaceAgent, _, err := getWorkspaceAndAgent(ctx, inv, client, false, workspaceName) + workspace, workspaceAgent, _, err := GetWorkspaceAndAgent(ctx, inv, client, false, workspaceName) if err != nil { return err } @@ -131,7 +131,7 @@ func (r *RootCmd) immortalStreamDeleteCmd() *serpent.Command { workspaceName := inv.Args[0] streamName := inv.Args[1] - workspace, workspaceAgent, _, err := getWorkspaceAndAgent(ctx, inv, client, false, workspaceName) + workspace, workspaceAgent, _, err := GetWorkspaceAndAgent(ctx, inv, client, false, workspaceName) if err != nil { return err } diff --git a/cli/ssh.go b/cli/ssh.go index 478473d294ee3..13c8a0dd00e6f 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -48,7 +48,6 @@ import ( "github.com/coder/quartz" "github.com/coder/retry" "github.com/coder/serpent" - "github.com/coder/websocket" ) const ( @@ -437,8 +436,23 @@ func (r *RootCmd) ssh() *serpent.Command { streamID = &stream.ID logger.Info(ctx, "created immortal stream for SSH", slog.F("stream_name", stream.Name), slog.F("stream_id", stream.ID)) - // Connect to the immortal stream via WebSocket - rawSSH, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger) + // Connect to the immortal stream via WebSocket + // Provide a refresh callback to recreate the stream if the agent indicates it's invalid + refresh := func(rctx context.Context) (uuid.UUID, error) { + // Try to create a replacement stream to SSH port 1 + s, rerr := immortalStreamClient.createStream(rctx, 1) + if rerr != nil { + logger.Error(rctx, "failed to refresh immortal 
stream id", slog.Error(rerr))
+ return uuid.UUID{}, rerr
+ }
+ logger.Info(rctx, "refreshed immortal stream id", slog.F("old", stream.ID), slog.F("new", s.ID))
+ // Note: we intentionally do not delete the old stream; if the agent restarted, it no longer exists.
+ // If it still exists, users can garbage-collect it via the list/delete commands.
+ // Update the outer streamID so signal-based cleanup targets the new stream
+ stream.ID = s.ID
+ return s.ID, nil
+ }
+ rawSSH, err = newImmortalReconnectingConn(ctx, conn, stream.ID, logger, refresh)
 if err != nil {
 // Only clean up the stream if it's a permanent failure
 if !isNetworkError(err) {
@@ -556,8 +570,18 @@ func (r *RootCmd) ssh() *serpent.Command {
 streamID = &stream.ID
 logger.Info(ctx, "created immortal stream for SSH", slog.F("stream_name", stream.Name), slog.F("stream_id", stream.ID))
- // Connect to the immortal stream and create SSH client
- rawConn, err := connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger)
+ // Connect to the immortal stream and create SSH client
+ refresh := func(rctx context.Context) (uuid.UUID, error) {
+ s, rerr := immortalStreamClient.createStream(rctx, 1)
+ if rerr != nil {
+ logger.Error(rctx, "failed to refresh immortal stream id", slog.Error(rerr))
+ return uuid.UUID{}, rerr
+ }
+ logger.Info(rctx, "refreshed immortal stream id", slog.F("old", stream.ID), slog.F("new", s.ID))
+ stream.ID = s.ID
+ return s.ID, nil
+ }
+ rawConn, err := newImmortalReconnectingConn(ctx, conn, stream.ID, logger, refresh)
 if err != nil {
 // Only clean up the stream if it's a permanent failure
 if !isNetworkError(err) {
@@ -930,36 +954,11 @@ func (r *RootCmd) ssh() *serpent.Command {
 return cmd
 }
-// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket
-// The immortal stream infrastructure handles reconnection automatically
-func connectToImmortalStreamWebSocket(ctx context.Context, agentConn *workspacesdk.AgentConn, streamID uuid.UUID, logger slog.Logger) (net.Conn, error) {
- // Build the target address for the agent's HTTP API server
- apiServerAddr := fmt.Sprintf("127.0.0.1:%d", workspacesdk.AgentHTTPAPIServerPort)
- wsURL := fmt.Sprintf("ws://%s/api/v0/immortal-stream/%s", apiServerAddr, streamID)
-
- // Create WebSocket connection using the agent's tailnet connection
- dialOptions := &websocket.DialOptions{
- HTTPClient: &http.Client{
- Transport: &http.Transport{
- DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
- return agentConn.DialContext(dialCtx, network, addr)
- },
- },
- },
- CompressionMode: websocket.CompressionDisabled,
- }
-
- // Connect to the WebSocket endpoint
- conn, _, err := websocket.Dial(ctx, wsURL, dialOptions)
- if err != nil {
- return nil, xerrors.Errorf("dial immortal stream WebSocket: %w", err)
- }
-
- // Convert WebSocket to net.Conn for SSH usage
- // The immortal stream's BackedPipe handles reconnection automatically
- netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary)
-
- return netConn, nil
+// connectToImmortalStreamWebSocket connects to an immortal stream via a reconnecting wrapper.
+func connectToImmortalStreamWebSocket(ctx context.Context, agentConn workspacesdk.AgentConn, streamID uuid.UUID, logger slog.Logger) (net.Conn, error) {
+ // No generic way to recreate the stream here without knowing the target port.
+ // The SSH caller provides a refresh callback using the known target port (1).
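+ // With a nil refresh callback the wrapper keeps redialing the original
+ // stream ID; if the agent restarted and that ID no longer exists, every
+ // attempt fails until the caller closes the connection.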
+ return newImmortalReconnectingConn(ctx, agentConn, streamID, logger, nil) } // isNetworkError checks if an error is a temporary network error diff --git a/codersdk/workspacesdk/agentconn.go b/codersdk/workspacesdk/agentconn.go index 36dd471712a3c..21d036edaccde 100644 --- a/codersdk/workspacesdk/agentconn.go +++ b/codersdk/workspacesdk/agentconn.go @@ -67,6 +67,11 @@ type AgentConn interface { SSHOnPort(ctx context.Context, port uint16) (*gonet.TCPConn, error) Speedtest(ctx context.Context, direction speedtest.Direction, duration time.Duration) ([]speedtest.Result, error) WatchContainers(ctx context.Context, logger slog.Logger) (<-chan codersdk.WorkspaceAgentListContainersResponse, io.Closer, error) + + // Agent HTTP API: Immortal Streams + ImmortalStreams(ctx context.Context) ([]codersdk.ImmortalStream, error) + CreateImmortalStream(ctx context.Context, req codersdk.CreateImmortalStreamRequest) (codersdk.ImmortalStream, error) + DeleteImmortalStream(ctx context.Context, streamID uuid.UUID) error } // AgentConn represents a connection to a workspace agent. @@ -314,7 +319,7 @@ func (c *agentConn) ListeningPorts(ctx context.Context) (codersdk.WorkspaceAgent } // ImmortalStreams lists the immortal streams that are currently active in the workspace. -func (c *AgentConn) ImmortalStreams(ctx context.Context) ([]codersdk.ImmortalStream, error) { +func (c *agentConn) ImmortalStreams(ctx context.Context) ([]codersdk.ImmortalStream, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() res, err := c.apiRequest(ctx, http.MethodGet, "/api/v0/immortal-stream", nil) @@ -331,7 +336,7 @@ func (c *AgentConn) ImmortalStreams(ctx context.Context) ([]codersdk.ImmortalStr } // CreateImmortalStream creates a new immortal stream to the specified port. -func (c *AgentConn) CreateImmortalStream(ctx context.Context, req codersdk.CreateImmortalStreamRequest) (codersdk.ImmortalStream, error) { +func (c *agentConn) CreateImmortalStream(ctx context.Context, req codersdk.CreateImmortalStreamRequest) (codersdk.ImmortalStream, error) { ctx, span := tracing.StartSpan(ctx) defer span.End() @@ -363,7 +368,7 @@ func (c *AgentConn) CreateImmortalStream(ctx context.Context, req codersdk.Creat } // DeleteImmortalStream deletes an immortal stream by ID. -func (c *AgentConn) DeleteImmortalStream(ctx context.Context, streamID uuid.UUID) error { +func (c *agentConn) DeleteImmortalStream(ctx context.Context, streamID uuid.UUID) error { ctx, span := tracing.StartSpan(ctx) defer span.End() diff --git a/codersdk/workspacesdk/agentconnmock/agentconnmock.go b/codersdk/workspacesdk/agentconnmock/agentconnmock.go index eb55bb27938c0..fa9d8ed95eaf7 100644 --- a/codersdk/workspacesdk/agentconnmock/agentconnmock.go +++ b/codersdk/workspacesdk/agentconnmock/agentconnmock.go @@ -81,6 +81,21 @@ func (mr *MockAgentConnMockRecorder) Close() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAgentConn)(nil).Close)) } +// CreateImmortalStream mocks base method. +func (m *MockAgentConn) CreateImmortalStream(ctx context.Context, req codersdk.CreateImmortalStreamRequest) (codersdk.ImmortalStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateImmortalStream", ctx, req) + ret0, _ := ret[0].(codersdk.ImmortalStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateImmortalStream indicates an expected call of CreateImmortalStream. 
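+//
+// A hypothetical test could wire the mock up as follows (assuming the
+// standard gomock-generated NewMockAgentConn constructor):
+//
+//	ctrl := gomock.NewController(t)
+//	mConn := NewMockAgentConn(ctrl)
+//	mConn.EXPECT().
+//		CreateImmortalStream(gomock.Any(), gomock.Any()).
+//		Return(codersdk.ImmortalStream{}, nil)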
+func (mr *MockAgentConnMockRecorder) CreateImmortalStream(ctx, req any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateImmortalStream", reflect.TypeOf((*MockAgentConn)(nil).CreateImmortalStream), ctx, req) +} + // DebugLogs mocks base method. func (m *MockAgentConn) DebugLogs(ctx context.Context) ([]byte, error) { m.ctrl.T.Helper() @@ -126,6 +141,20 @@ func (mr *MockAgentConnMockRecorder) DebugManifest(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugManifest", reflect.TypeOf((*MockAgentConn)(nil).DebugManifest), ctx) } +// DeleteImmortalStream mocks base method. +func (m *MockAgentConn) DeleteImmortalStream(ctx context.Context, streamID uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteImmortalStream", ctx, streamID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteImmortalStream indicates an expected call of DeleteImmortalStream. +func (mr *MockAgentConnMockRecorder) DeleteImmortalStream(ctx, streamID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteImmortalStream", reflect.TypeOf((*MockAgentConn)(nil).DeleteImmortalStream), ctx, streamID) +} + // DialContext mocks base method. func (m *MockAgentConn) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { m.ctrl.T.Helper() @@ -155,6 +184,21 @@ func (mr *MockAgentConnMockRecorder) GetPeerDiagnostics() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerDiagnostics", reflect.TypeOf((*MockAgentConn)(nil).GetPeerDiagnostics)) } +// ImmortalStreams mocks base method. +func (m *MockAgentConn) ImmortalStreams(ctx context.Context) ([]codersdk.ImmortalStream, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ImmortalStreams", ctx) + ret0, _ := ret[0].([]codersdk.ImmortalStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ImmortalStreams indicates an expected call of ImmortalStreams. +func (mr *MockAgentConnMockRecorder) ImmortalStreams(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ImmortalStreams", reflect.TypeOf((*MockAgentConn)(nil).ImmortalStreams), ctx) +} + // ListContainers mocks base method. func (m *MockAgentConn) ListContainers(ctx context.Context) (codersdk.WorkspaceAgentListContainersResponse, error) { m.ctrl.T.Helper()