Skip to content

Commit 4996ebe

Browse files
committed
WIP
1 parent 2355f74 commit 4996ebe

File tree

2 files changed

+125
-96
lines changed

2 files changed

+125
-96
lines changed

agent/agent.go

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -70,44 +70,13 @@ const (
7070
EnvProcOOMScore = "CODER_PROC_OOM_SCORE"
7171
)
7272

73-
// agentImmortalDialer is a custom dialer for immortal streams that can
74-
// connect to the agent's own services via tailnet addresses.
73+
// agentImmortalDialer wraps the standard dialer for immortal streams.
74+
// The agent's SSH and HTTP API servers also listen on 127.0.0.1 (see createTailnet), so localhost addresses are dialed directly.
7575
type agentImmortalDialer struct {
76-
agent *agent
7776
standardDialer *net.Dialer
7877
}
7978

8079
func (d *agentImmortalDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
81-
host, portStr, err := net.SplitHostPort(address)
82-
if err != nil {
83-
return nil, xerrors.Errorf("split host port %q: %w", address, err)
84-
}
85-
86-
port, err := strconv.Atoi(portStr)
87-
if err != nil {
88-
return nil, xerrors.Errorf("parse port %q: %w", portStr, err)
89-
}
90-
91-
// Check if this is a connection to one of the agent's own services
92-
isLocalhost := host == "localhost" || host == "127.0.0.1" || host == "::1"
93-
isAgentPort := port == int(workspacesdk.AgentSSHPort) || port == int(workspacesdk.AgentHTTPAPIServerPort) ||
94-
port == int(workspacesdk.AgentReconnectingPTYPort) || port == int(workspacesdk.AgentSpeedtestPort)
95-
96-
if isLocalhost && isAgentPort {
97-
// Get the agent ID from the current manifest
98-
manifest := d.agent.manifest.Load()
99-
if manifest == nil || manifest.AgentID == uuid.Nil {
100-
// Fallback to standard dialing if no manifest available yet
101-
return d.standardDialer.DialContext(ctx, network, address)
102-
}
103-
104-
// Connect to the agent's own tailnet address instead of localhost
105-
agentAddr := tailnet.TailscaleServicePrefix.AddrFromUUID(manifest.AgentID)
106-
agentAddress := net.JoinHostPort(agentAddr.String(), portStr)
107-
return d.standardDialer.DialContext(ctx, network, agentAddress)
108-
}
109-
110-
// For other addresses, use standard dialing
11180
return d.standardDialer.DialContext(ctx, network, address)
11281
}
11382

@@ -392,10 +361,8 @@ func (a *agent) init() {
392361

393362
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
394363

395-
// Initialize immortal streams manager with a custom dialer
396-
// that can connect to the agent's own services
364+
// Initialize immortal streams manager
397365
immortalDialer := &agentImmortalDialer{
398-
agent: a,
399366
standardDialer: &net.Dialer{},
400367
}
401368
a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), immortalDialer)
@@ -1531,6 +1498,7 @@ func (a *agent) createTailnet(
15311498
}
15321499

15331500
for _, port := range []int{workspacesdk.AgentSSHPort, workspacesdk.AgentStandardSSHPort} {
1501+
// Listen on tailnet interface for external connections
15341502
sshListener, err := network.Listen("tcp", ":"+strconv.Itoa(port))
15351503
if err != nil {
15361504
return nil, xerrors.Errorf("listen on the ssh port (%v): %w", port, err)
@@ -1546,6 +1514,25 @@ func (a *agent) createTailnet(
15461514
}); err != nil {
15471515
return nil, err
15481516
}
1517+
1518+
// Also listen on localhost for immortal streams (only on workspacesdk.AgentSSHPort, not the standard SSH port)
1519+
if port == workspacesdk.AgentSSHPort {
1520+
localhostListener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(port))
1521+
if err != nil {
1522+
return nil, xerrors.Errorf("listen on localhost ssh port (%v): %w", port, err)
1523+
}
1524+
// nolint:revive // We do want to run the deferred functions when createTailnet returns.
1525+
defer func() {
1526+
if err != nil {
1527+
_ = localhostListener.Close()
1528+
}
1529+
}()
1530+
if err = a.trackGoroutine(func() {
1531+
_ = a.sshServer.Serve(localhostListener)
1532+
}); err != nil {
1533+
return nil, err
1534+
}
1535+
}
15491536
}
15501537

15511538
reconnectingPTYListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentReconnectingPTYPort))
@@ -1616,6 +1603,7 @@ func (a *agent) createTailnet(
16161603
return nil, err
16171604
}
16181605

1606+
// Listen on tailnet interface for external connections
16191607
apiListener, err := network.Listen("tcp", ":"+strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort))
16201608
if err != nil {
16211609
return nil, xerrors.Errorf("api listener: %w", err)
@@ -1652,6 +1640,43 @@ func (a *agent) createTailnet(
16521640
return nil, err
16531641
}
16541642

1643+
// Also listen on localhost for immortal streams WebSocket connections
1644+
localhostAPIListener, err := net.Listen("tcp", "127.0.0.1:"+strconv.Itoa(workspacesdk.AgentHTTPAPIServerPort))
1645+
if err != nil {
1646+
return nil, xerrors.Errorf("localhost api listener: %w", err)
1647+
}
1648+
defer func() {
1649+
if err != nil {
1650+
_ = localhostAPIListener.Close()
1651+
}
1652+
}()
1653+
if err = a.trackGoroutine(func() {
1654+
defer localhostAPIListener.Close()
1655+
apiHandler := a.apiHandler()
1656+
server := &http.Server{
1657+
BaseContext: func(net.Listener) context.Context { return ctx },
1658+
Handler: apiHandler,
1659+
ReadTimeout: 20 * time.Second,
1660+
ReadHeaderTimeout: 20 * time.Second,
1661+
WriteTimeout: 20 * time.Second,
1662+
ErrorLog: slog.Stdlib(ctx, a.logger.Named("http_api_server_localhost"), slog.LevelInfo),
1663+
}
1664+
go func() {
1665+
select {
1666+
case <-ctx.Done():
1667+
case <-a.hardCtx.Done():
1668+
}
1669+
_ = server.Close()
1670+
}()
1671+
1672+
apiServErr := server.Serve(localhostAPIListener)
1673+
if apiServErr != nil && !xerrors.Is(apiServErr, http.ErrServerClosed) && !strings.Contains(apiServErr.Error(), "use of closed network connection") {
1674+
a.logger.Critical(ctx, "serve localhost HTTP API server", slog.Error(apiServErr))
1675+
}
1676+
}); err != nil {
1677+
return nil, err
1678+
}
1679+
16551680
return network, nil
16561681
}
16571682

cli/ssh.go

Lines changed: 64 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,10 @@ func (r *RootCmd) ssh() *serpent.Command {
440440
// Connect to the immortal stream via WebSocket
441441
rawSSH, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger)
442442
if err != nil {
443-
// Clean up the stream if connection fails
444-
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
443+
// Only clean up the stream if it's a permanent failure
444+
if !isNetworkError(err) {
445+
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
446+
}
445447
return xerrors.Errorf("connect to immortal stream: %w", err)
446448
}
447449
}
@@ -481,25 +483,25 @@ func (r *RootCmd) ssh() *serpent.Command {
481483
}
482484
}
483485

484-
// Set up cleanup for immortal stream
486+
// Set up signal-based cleanup for immortal stream
487+
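// Intended to delete only on explicit user termination (StopSignals), not on network errors. NOTE(review): the deferred signalStop() likely also ends signalCtx, which would wake the cleanup goroutine on every return - confirm.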
485488
if immortalStreamClient != nil && streamID != nil {
486-
defer func() {
487-
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
488-
logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err))
489-
}
489+
// Create a signal-only context for cleanup
490+
signalCtx, signalStop := inv.SignalNotifyContext(context.Background(), StopSignals...)
491+
defer signalStop()
492+
493+
go func() {
494+
<-signalCtx.Done()
495+
// User sent termination signal - clean up the stream
496+
_ = immortalStreamClient.deleteStream(context.Background(), *streamID)
490497
}()
491498
}
492499

493500
wg.Add(1)
494501
go func() {
495502
defer wg.Done()
496503
watchAndClose(ctx, func() error {
497-
// Clean up immortal stream on termination
498-
if immortalStreamClient != nil && streamID != nil {
499-
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
500-
logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err))
501-
}
502-
}
504+
// Don't delete immortal stream here - let signal handler do it
503505
stack.close(xerrors.New("watchAndClose"))
504506
return nil
505507
}, logger, client, workspace, errCh)
@@ -557,8 +559,10 @@ func (r *RootCmd) ssh() *serpent.Command {
557559
// Connect to the immortal stream and create SSH client
558560
rawConn, err := connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger)
559561
if err != nil {
560-
// Clean up the stream if connection fails
561-
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
562+
// Only clean up the stream if it's a permanent failure
563+
if !isNetworkError(err) {
564+
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
565+
}
562566
return xerrors.Errorf("connect to immortal stream: %w", err)
563567
}
564568

@@ -569,7 +573,10 @@ func (r *RootCmd) ssh() *serpent.Command {
569573
})
570574
if err != nil {
571575
rawConn.Close()
572-
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
576+
// Only clean up the stream if it's a permanent failure
577+
if !isNetworkError(err) {
578+
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
579+
}
573580
return xerrors.Errorf("ssh handshake over immortal stream: %w", err)
574581
}
575582

@@ -603,12 +610,17 @@ func (r *RootCmd) ssh() *serpent.Command {
603610
}
604611
}
605612

606-
// Set up cleanup for immortal stream in regular SSH mode
613+
// Set up signal-based cleanup for immortal stream
614+
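// Intended to delete only on explicit user termination (StopSignals), not on network errors. NOTE(review): the deferred signalStop() likely also ends signalCtx, which would wake the cleanup goroutine on every return - confirm.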
607615
if immortalStreamClient != nil && streamID != nil {
608-
defer func() {
609-
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
610-
logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err))
611-
}
616+
// Create a signal-only context for cleanup
617+
signalCtx, signalStop := inv.SignalNotifyContext(context.Background(), StopSignals...)
618+
defer signalStop()
619+
620+
go func() {
621+
<-signalCtx.Done()
622+
// User sent termination signal - clean up the stream
623+
_ = immortalStreamClient.deleteStream(context.Background(), *streamID)
612624
}()
613625
}
614626

@@ -618,12 +630,7 @@ func (r *RootCmd) ssh() *serpent.Command {
618630
watchAndClose(
619631
ctx,
620632
func() error {
621-
// Clean up immortal stream on termination
622-
if immortalStreamClient != nil && streamID != nil {
623-
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
624-
logger.Error(context.Background(), "failed to cleanup immortal stream on termination", slog.Error(err))
625-
}
626-
}
633+
// Don't delete immortal stream here - let signal handler do it
627634
stack.close(xerrors.New("watchAndClose"))
628635
return nil
629636
},
@@ -923,66 +930,63 @@ func (r *RootCmd) ssh() *serpent.Command {
923930
return cmd
924931
}
925932

926-
// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket and returns a net.Conn
933+
// connectToImmortalStreamWebSocket connects to an immortal stream via WebSocket and returns the connection wrapped as a net.Conn.
934+
// The immortal stream infrastructure handles reconnection automatically
927935
func connectToImmortalStreamWebSocket(ctx context.Context, agentConn *workspacesdk.AgentConn, streamID uuid.UUID, logger slog.Logger) (net.Conn, error) {
928936
// Build the target address for the agent's HTTP API server
929-
// We'll let the WebSocket dialer handle the actual connection through the agent
930937
apiServerAddr := fmt.Sprintf("127.0.0.1:%d", workspacesdk.AgentHTTPAPIServerPort)
931938
wsURL := fmt.Sprintf("ws://%s/api/v0/immortal-stream/%s", apiServerAddr, streamID)
932939

933940
// Create WebSocket connection using the agent's tailnet connection
934-
// The key is to use a custom dialer that routes through the agent connection
935941
dialOptions := &websocket.DialOptions{
936942
HTTPClient: &http.Client{
937943
Transport: &http.Transport{
938944
DialContext: func(dialCtx context.Context, network, addr string) (net.Conn, error) {
939-
// Route all connections through the agent connection
940-
// The agent connection will handle routing to the correct internal address
941-
942-
conn, err := agentConn.DialContext(dialCtx, network, addr)
943-
if err != nil {
944-
return nil, err
945-
}
946-
947-
return conn, nil
945+
return agentConn.DialContext(dialCtx, network, addr)
948946
},
949947
},
950948
},
951-
// Disable compression for raw TCP data
952949
CompressionMode: websocket.CompressionDisabled,
953950
}
954951

955952
// Connect to the WebSocket endpoint
956-
conn, res, err := websocket.Dial(ctx, wsURL, dialOptions)
953+
conn, _, err := websocket.Dial(ctx, wsURL, dialOptions)
957954
if err != nil {
958-
if res != nil {
959-
logger.Error(ctx, "WebSocket dial failed",
960-
slog.F("stream_id", streamID),
961-
slog.F("websocket_url", wsURL),
962-
slog.F("status", res.StatusCode),
963-
slog.F("status_text", res.Status),
964-
slog.Error(err))
965-
} else {
966-
logger.Error(ctx, "WebSocket dial failed (no response)",
967-
slog.F("stream_id", streamID),
968-
slog.F("websocket_url", wsURL),
969-
slog.Error(err))
970-
}
971955
return nil, xerrors.Errorf("dial immortal stream WebSocket: %w", err)
972956
}
973957

974-
logger.Info(ctx, "successfully connected to immortal stream WebSocket",
975-
slog.F("stream_id", streamID))
976-
977958
// Convert WebSocket to net.Conn for SSH usage
978-
// Use MessageBinary for raw TCP data transport
959+
// The immortal stream's BackedPipe handles reconnection automatically
979960
netConn := websocket.NetConn(ctx, conn, websocket.MessageBinary)
980961

981-
logger.Debug(ctx, "converted WebSocket to net.Conn for SSH usage")
982-
983962
return netConn, nil
984963
}
985964

965+
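// isNetworkError reports whether err's message contains a known transient network failure string such as "connection refused" or "timeout"; a nil err returns false.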
966+
func isNetworkError(err error) bool {
967+
if err == nil {
968+
return false
969+
}
970+
971+
errStr := err.Error()
972+
networkErrors := []string{
973+
"connection refused",
974+
"network is unreachable",
975+
"connection reset",
976+
"broken pipe",
977+
"timeout",
978+
"no route to host",
979+
}
980+
981+
for _, netErr := range networkErrors {
982+
if strings.Contains(errStr, netErr) {
983+
return true
984+
}
985+
}
986+
987+
return false
988+
}
989+
986990
// findWorkspaceAndAgentByHostname parses the hostname from the commandline and finds the workspace and agent it
987991
// corresponds to, taking into account any name prefixes or suffixes configured (e.g. myworkspace.coder, or
988992
// vscode-coder--myusername--myworkspace).

0 commit comments

Comments
 (0)