Skip to content

Commit e73338d

Browse files
committed
WIP
1 parent dccaf9b commit e73338d

File tree

1 file changed

+92
-8
lines changed

1 file changed

+92
-8
lines changed

cli/portforward.go

Lines changed: 92 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"sync"
1414
"syscall"
1515

16+
"github.com/google/uuid"
1617
"golang.org/x/xerrors"
1718

1819
"cdr.dev/slog"
@@ -152,15 +153,15 @@ func (r *RootCmd) portForward() *serpent.Command {
152153
// first, opportunistically try to listen on IPv6
153154
spec6 := spec
154155
spec6.listenHost = ipv6Loopback
155-
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger)
156+
l6, err6 := listenAndPortForward(ctx, inv, conn, wg, spec6, logger, immortal, immortalFallback, client, workspaceAgent.ID)
156157
if err6 != nil {
157158
logger.Info(ctx, "failed to opportunistically listen on IPv6", slog.F("spec", spec), slog.Error(err6))
158159
} else {
159160
listeners = append(listeners, l6)
160161
}
161162
spec.listenHost = ipv4Loopback
162163
}
163-
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger)
164+
l, err := listenAndPortForward(ctx, inv, conn, wg, spec, logger, immortal, immortalFallback, client, workspaceAgent.ID)
164165
if err != nil {
165166
logger.Error(ctx, "failed to listen", slog.F("spec", spec), slog.Error(err))
166167
return err
@@ -242,6 +243,10 @@ func listenAndPortForward(
242243
wg *sync.WaitGroup,
243244
spec portForwardSpec,
244245
logger slog.Logger,
246+
immortal bool,
247+
immortalFallback bool,
248+
client *codersdk.Client,
249+
agentID uuid.UUID,
245250
) (net.Listener, error) {
246251
logger = logger.With(
247252
slog.F("network", spec.network),
@@ -281,17 +286,96 @@ func listenAndPortForward(
281286

282287
go func(netConn net.Conn) {
283288
defer netConn.Close()
284-
remoteConn, err := conn.DialContext(ctx, spec.network, dialAddress)
285-
if err != nil {
286-
_, _ = fmt.Fprintf(inv.Stderr,
287-
"Failed to dial '%s://%s' in workspace: %s\n",
288-
spec.network, dialAddress, err)
289-
return
289+
290+
var remoteConn net.Conn
291+
var immortalStreamClient *immortalStreamClient
292+
var streamID *uuid.UUID
293+
294+
// Only use immortal streams for TCP connections
295+
if immortal && spec.network == "tcp" {
296+
// Create immortal stream client
297+
immortalStreamClient = newImmortalStreamClient(client, agentID, logger)
298+
299+
// Create immortal stream to the target port
300+
stream, err := immortalStreamClient.createStream(ctx, int(spec.dialPort))
301+
if err != nil {
302+
logger.Error(ctx, "failed to create immortal stream for port forward",
303+
slog.Error(err),
304+
slog.F("agent_id", agentID),
305+
slog.F("target_port", spec.dialPort),
306+
slog.F("immortal_fallback_enabled", immortalFallback))
307+
308+
shouldFallback := immortalFallback && (strings.Contains(err.Error(), "Too many immortal streams") ||
309+
strings.Contains(err.Error(), "The connection was refused"))
310+
311+
if shouldFallback {
312+
if strings.Contains(err.Error(), "Too many immortal streams") {
313+
logger.Warn(ctx, "too many immortal streams, falling back to regular port forward",
314+
slog.F("max_streams", "32"),
315+
slog.F("target_port", spec.dialPort))
316+
} else {
317+
logger.Warn(ctx, "service not available, falling back to regular port forward",
318+
slog.F("reason", "connection_refused"),
319+
slog.F("target_port", spec.dialPort))
320+
}
321+
logger.Debug(ctx, "attempting fallback to regular port forward")
322+
remoteConn, err = conn.DialContext(ctx, spec.network, dialAddress)
323+
if err != nil {
324+
logger.Error(ctx, "fallback port forward also failed", slog.Error(err))
325+
_, _ = fmt.Fprintf(inv.Stderr,
326+
"Failed to dial '%s://%s' in workspace: %s\n",
327+
spec.network, dialAddress, err)
328+
return
329+
}
330+
logger.Debug(ctx, "successfully connected via regular port forward fallback")
331+
} else {
332+
_, _ = fmt.Fprintf(inv.Stderr,
333+
"Failed to create immortal stream for '%s://%s' in workspace: %s\n",
334+
spec.network, dialAddress, err)
335+
return
336+
}
337+
} else {
338+
streamID = &stream.ID
339+
logger.Debug(ctx, "created immortal stream for port forward",
340+
slog.F("stream_name", stream.Name),
341+
slog.F("stream_id", stream.ID),
342+
slog.F("target_port", spec.dialPort))
343+
344+
// Connect to the immortal stream via WebSocket
345+
remoteConn, err = connectToImmortalStreamWebSocket(ctx, conn, stream.ID, logger)
346+
if err != nil {
347+
// Clean up the stream if connection fails
348+
_ = immortalStreamClient.deleteStream(ctx, stream.ID)
349+
_, _ = fmt.Fprintf(inv.Stderr,
350+
"Failed to connect to immortal stream for '%s://%s' in workspace: %s\n",
351+
spec.network, dialAddress, err)
352+
return
353+
}
354+
}
355+
} else {
356+
// Use regular connection for UDP or when immortal is disabled
357+
remoteConn, err = conn.DialContext(ctx, spec.network, dialAddress)
358+
if err != nil {
359+
_, _ = fmt.Fprintf(inv.Stderr,
360+
"Failed to dial '%s://%s' in workspace: %s\n",
361+
spec.network, dialAddress, err)
362+
return
363+
}
290364
}
365+
291366
defer remoteConn.Close()
292367
logger.Debug(ctx,
293368
"dialed remote", slog.F("remote_addr", netConn.RemoteAddr()))
294369

370+
// Set up cleanup for immortal stream
371+
if immortalStreamClient != nil && streamID != nil {
372+
defer func() {
373+
if err := immortalStreamClient.deleteStream(context.Background(), *streamID); err != nil {
374+
logger.Error(context.Background(), "failed to cleanup immortal stream", slog.Error(err))
375+
}
376+
}()
377+
}
378+
295379
agentssh.Bicopy(ctx, netConn, remoteConn)
296380
logger.Debug(ctx,
297381
"connection closing", slog.F("remote_addr", netConn.RemoteAddr()))

0 commit comments

Comments
 (0)