Skip to content

Commit 9bcaa2f

Browse files
committed
WIP
1 parent 57b68a4 commit 9bcaa2f

File tree

4 files changed

+156
-124
lines changed

4 files changed

+156
-124
lines changed

agent/immortalstreams/manager.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ func (m *Manager) CreateStream(ctx context.Context, port int) (*codersdk.Immorta
8080
name,
8181
port,
8282
m.logger.With(slog.F("stream_id", id), slog.F("stream_name", name)),
83+
m.dialer,
8384
)
8485

8586
// Start the stream

agent/immortalstreams/manager_test.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,24 @@ func TestManager_CreateStream(t *testing.T) {
4040
if err != nil {
4141
return
4242
}
43-
// Just echo for testing
44-
go func() {
45-
defer conn.Close()
46-
_, _ = io.Copy(conn, conn)
47-
}()
43+
// Just echo for testing with proper cleanup
44+
go func(c net.Conn) {
45+
defer c.Close()
46+
// Echo via an explicit read/write loop; Read still blocks, but the goroutine exits on the first read or write error
47+
buf := make([]byte, 1024)
48+
for {
49+
n, err := c.Read(buf)
50+
if err != nil {
51+
return
52+
}
53+
if n > 0 {
54+
_, err = c.Write(buf[:n])
55+
if err != nil {
56+
return
57+
}
58+
}
59+
}
60+
}(conn)
4861
}
4962
}()
5063

agent/immortalstreams/stream.go

Lines changed: 125 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package immortalstreams
33
import (
44
"context"
55
"errors"
6+
"fmt"
67
"io"
78
"sync"
89
"time"
@@ -22,6 +23,7 @@ type Stream struct {
2223
port int
2324
createdAt time.Time
2425
logger slog.Logger
26+
dialer Dialer
2527

2628
mu sync.RWMutex
2729
localConn io.ReadWriteCloser
@@ -56,8 +58,7 @@ type Stream struct {
5658

5759
// reconnectRequest represents a pending reconnection request
5860
type reconnectRequest struct {
59-
writerSeqNum uint64
60-
response chan reconnectResponse
61+
response chan reconnectResponse
6162
}
6263

6364
// reconnectResponse represents a reconnection response
@@ -67,83 +68,75 @@ type reconnectResponse struct {
6768
err error
6869
}
6970

71+
// streamReconnector implements the backedpipe.Reconnector interface for Stream
72+
type streamReconnector struct {
73+
stream *Stream
74+
// Track the most recently dialed local connection (guarded by mu). NOTE(review): nothing visible in this diff closes the previous connection on reconnect — confirm.
75+
mu sync.Mutex
76+
currentConn io.ReadWriteCloser
77+
}
78+
79+
// Reconnect implements the backedpipe.Reconnector interface
80+
func (sr *streamReconnector) Reconnect(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
81+
sr.stream.logger.Info(context.Background(), "reconnector attempting to dial local connection",
82+
slog.F("port", sr.stream.port),
83+
slog.F("writer_seq", writerSeqNum))
84+
85+
// Dial the local TCP port directly
86+
conn, err := sr.stream.dialer.DialContext(ctx, "tcp", fmt.Sprintf("localhost:%d", sr.stream.port))
87+
if err != nil {
88+
sr.stream.logger.Warn(context.Background(), "failed to dial local connection",
89+
slog.Error(err),
90+
slog.F("port", sr.stream.port))
91+
return nil, 0, err
92+
}
93+
94+
sr.stream.logger.Info(context.Background(), "successfully dialed local connection",
95+
slog.F("port", sr.stream.port))
96+
97+
// Store the new connection for tracking
98+
sr.mu.Lock()
99+
sr.currentConn = conn
100+
sr.mu.Unlock()
101+
102+
// Update stream state
103+
sr.stream.mu.Lock()
104+
sr.stream.connected = true
105+
sr.stream.lastConnectionAt = time.Now()
106+
if sr.stream.reconnectCond != nil {
107+
sr.stream.reconnectCond.Broadcast()
108+
}
109+
sr.stream.mu.Unlock()
110+
111+
return conn, 0, nil // Start from sequence 0 for new connections
112+
}
113+
70114
// NewStream creates a new immortal stream
71-
func NewStream(id uuid.UUID, name string, port int, logger slog.Logger) *Stream {
115+
func NewStream(id uuid.UUID, name string, port int, logger slog.Logger, dialer Dialer) *Stream {
72116
stream := &Stream{
73117
id: id,
74118
name: name,
75119
port: port,
76120
createdAt: time.Now(),
77121
logger: logger,
122+
dialer: dialer,
78123
disconnectChan: make(chan struct{}, 1),
79124
shutdownChan: make(chan struct{}),
80125
reconnectReq: make(chan struct{}, 1),
81126
}
82127
stream.reconnectCond = sync.NewCond(&stream.mu)
83128

84-
// Create a reconnect function that waits for a client connection
85-
reconnectFn := func(ctx context.Context, writerSeqNum uint64) (io.ReadWriteCloser, uint64, error) {
86-
// Wait for HandleReconnect to be called with a new connection
87-
responseChan := make(chan reconnectResponse, 1)
88-
89-
stream.mu.Lock()
90-
stream.pendingReconnect = &reconnectRequest{
91-
writerSeqNum: writerSeqNum,
92-
response: responseChan,
93-
}
94-
stream.handshakePending = true
95-
// Mark disconnected if we previously had a client connection
96-
if stream.connected {
97-
stream.connected = false
98-
stream.lastDisconnectionAt = time.Now()
99-
}
100-
stream.logger.Info(context.Background(), "pending reconnect set",
101-
slog.F("writer_seq", writerSeqNum))
102-
// Signal waiters a reconnect request is pending
103-
stream.reconnectCond.Broadcast()
104-
stream.mu.Unlock()
105-
106-
// Fast path: if the stream is already shutting down, abort immediately
107-
select {
108-
case <-stream.shutdownChan:
109-
stream.mu.Lock()
110-
// Clear the pending request since we're aborting
111-
if stream.pendingReconnect != nil {
112-
stream.pendingReconnect = nil
113-
}
114-
stream.mu.Unlock()
115-
return nil, 0, xerrors.New("stream is shutting down")
116-
default:
117-
}
118-
119-
// Wait for response from HandleReconnect or context cancellation
120-
stream.logger.Info(context.Background(), "reconnect function waiting for response")
121-
select {
122-
case resp := <-responseChan:
123-
stream.logger.Info(context.Background(), "reconnect function got response",
124-
slog.F("has_conn", resp.conn != nil),
125-
slog.F("read_seq", resp.readSeq),
126-
slog.Error(resp.err))
127-
return resp.conn, resp.readSeq, resp.err
128-
case <-ctx.Done():
129-
// Context was canceled, clear pending request and return error
130-
stream.mu.Lock()
131-
stream.pendingReconnect = nil
132-
stream.handshakePending = false
133-
stream.mu.Unlock()
134-
return nil, 0, ctx.Err()
135-
case <-stream.shutdownChan:
136-
// Stream is being shut down, clear pending request and return error
137-
stream.mu.Lock()
138-
stream.pendingReconnect = nil
139-
stream.handshakePending = false
140-
stream.mu.Unlock()
141-
return nil, 0, xerrors.New("stream is shutting down")
142-
}
129+
// Create BackedPipe with background context and reconnector
130+
reconnector := &streamReconnector{
131+
stream: stream,
143132
}
133+
stream.pipe = backedpipe.NewBackedPipe(context.Background(), reconnector)
144134

145-
// Create BackedPipe with background context
146-
stream.pipe = backedpipe.NewBackedPipe(context.Background(), reconnectFn)
135+
// Initiate the first connection
136+
if err := stream.pipe.Connect(); err != nil {
137+
stream.logger.Warn(context.Background(), "failed to connect pipe initially", slog.Error(err))
138+
// Continue anyway - the pipe will retry connections as needed
139+
}
147140

148141
// Start reconnect worker: dedupe pokes and call ForceReconnect when safe.
149142
go func() {
@@ -240,22 +233,9 @@ func (s *Stream) HandleReconnect(clientConn io.ReadWriteCloser, readSeqNum uint6
240233
s.mu.Unlock()
241234
respCh <- reconnectResponse{conn: clientConn, readSeq: readSeqNum, err: nil}
242235

243-
// Wait until the pipe reports a connected state so the handshake fully completes.
244-
// Use a bounded timeout to avoid hanging forever in pathological cases.
245-
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
246-
err := s.pipe.WaitForConnection(ctx)
247-
cancel()
248-
if err != nil {
249-
s.mu.Lock()
250-
s.connected = false
251-
if s.reconnectCond != nil {
252-
s.reconnectCond.Broadcast()
253-
}
254-
s.mu.Unlock()
255-
s.logger.Warn(context.Background(), "failed to connect backed pipe", slog.Error(err))
256-
return xerrors.Errorf("failed to establish connection: %w", err)
257-
}
258-
236+
// The reconnector interface handles the connection establishment.
237+
// NOTE(review): nothing here waits for or verifies the connection after the response is sent, so the state below is updated optimistically.
238+
// We just need to update our state to reflect the successful connection.
259239
s.mu.Lock()
260240
s.lastConnectionAt = time.Now()
261241
s.connected = true
@@ -333,23 +313,35 @@ func (s *Stream) Close() error {
333313
s.handshakePending = false
334314
}
335315

336-
// Close the backed pipe
337-
if s.pipe != nil {
338-
if err := s.pipe.Close(); err != nil {
339-
s.logger.Warn(context.Background(), "failed to close backed pipe", slog.Error(err))
340-
}
341-
}
342-
343-
// Close connections
316+
// Close connections first to unblock io.Copy operations
344317
if s.localConn != nil {
345318
if err := s.localConn.Close(); err != nil {
346319
s.logger.Warn(context.Background(), "failed to close local connection", slog.Error(err))
347320
}
348321
}
349322

350-
// Wait for goroutines to finish
323+
// Then close the backed pipe
324+
if s.pipe != nil {
325+
if err := s.pipe.Close(); err != nil {
326+
s.logger.Warn(context.Background(), "failed to close backed pipe", slog.Error(err))
327+
}
328+
}
329+
330+
// Wait for goroutines to finish with a timeout
351331
s.mu.Unlock()
352-
s.goroutines.Wait()
332+
done := make(chan struct{})
333+
go func() {
334+
defer close(done)
335+
s.goroutines.Wait()
336+
}()
337+
338+
select {
339+
case <-done:
340+
// Goroutines finished normally
341+
case <-time.After(5 * time.Second):
342+
// Timeout - log warning but continue
343+
s.logger.Warn(context.Background(), "timeout waiting for stream goroutines to finish during close")
344+
}
353345
s.mu.Lock()
354346

355347
return nil
@@ -403,8 +395,17 @@ func (s *Stream) startCopyingLocked() {
403395
defer s.goroutines.Done()
404396

405397
_, err := io.Copy(s.pipe, s.localConn)
406-
if err != nil && !xerrors.Is(err, io.EOF) && !xerrors.Is(err, io.ErrClosedPipe) {
407-
s.logger.Debug(context.Background(), "error copying from local to pipe", slog.Error(err))
398+
if err != nil {
399+
// Handle different error types appropriately
400+
if xerrors.Is(err, io.EOF) {
401+
s.logger.Debug(context.Background(), "local connection closed (EOF)")
402+
} else if xerrors.Is(err, io.ErrClosedPipe) || xerrors.Is(err, backedpipe.ErrPipeClosed) {
403+
s.logger.Debug(context.Background(), "pipe closed during copy", slog.Error(err))
404+
} else if xerrors.Is(err, backedpipe.ErrWriterClosed) {
405+
s.logger.Debug(context.Background(), "writer closed during copy", slog.Error(err))
406+
} else {
407+
s.logger.Debug(context.Background(), "error copying from local to pipe", slog.Error(err))
408+
}
408409
}
409410

410411
// Local connection closed, signal disconnection
@@ -426,13 +427,35 @@ func (s *Stream) startCopyingLocked() {
426427
for {
427428
// Use a buffer for copying
428429
n, err := s.pipe.Read(buf)
429-
// Log significant events
430-
if errors.Is(err, io.EOF) {
431-
s.logger.Debug(context.Background(), "got EOF from pipe")
432-
s.SignalDisconnect()
433-
} else if err != nil && !errors.Is(err, io.ErrClosedPipe) {
434-
s.logger.Debug(context.Background(), "error reading from pipe", slog.Error(err))
435-
s.SignalDisconnect()
430+
431+
// Handle different error types appropriately
432+
if err != nil {
433+
// Check for fatal errors that should terminate the goroutine
434+
if xerrors.Is(err, io.ErrClosedPipe) || xerrors.Is(err, backedpipe.ErrPipeClosed) {
435+
// The pipe itself is closed, we're done
436+
s.logger.Debug(context.Background(), "pipe closed, exiting copy goroutine", slog.Error(err))
437+
s.SignalDisconnect()
438+
return
439+
}
440+
441+
// Log various error types with appropriate context
442+
switch {
443+
case errors.Is(err, io.EOF):
444+
s.logger.Debug(context.Background(), "got EOF from pipe")
445+
s.SignalDisconnect()
446+
case xerrors.Is(err, backedpipe.ErrReconnectFailed):
447+
s.logger.Debug(context.Background(), "reconnect failed, pipe will retry", slog.Error(err))
448+
s.SignalDisconnect()
449+
case xerrors.Is(err, backedpipe.ErrReconnectionInProgress):
450+
s.logger.Debug(context.Background(), "reconnection in progress", slog.Error(err))
451+
// Don't signal disconnect - reconnection is already happening
452+
case xerrors.Is(err, backedpipe.ErrInvalidSequenceNumber):
453+
s.logger.Warn(context.Background(), "sequence number mismatch during reconnect", slog.Error(err))
454+
s.SignalDisconnect()
455+
default:
456+
s.logger.Debug(context.Background(), "error reading from pipe", slog.Error(err))
457+
s.SignalDisconnect()
458+
}
436459
}
437460

438461
if n > 0 {
@@ -447,14 +470,9 @@ func (s *Stream) startCopyingLocked() {
447470
}
448471

449472
if err != nil {
450-
// Check if this is a fatal error
451-
if xerrors.Is(err, io.ErrClosedPipe) {
452-
// The pipe itself is closed, we're done
453-
s.logger.Debug(context.Background(), "pipe closed, exiting copy goroutine")
454-
s.SignalDisconnect()
455-
return
456-
}
457-
// Any other error (including EOF) is handled by BackedPipe; continue
473+
// For non-fatal errors, BackedPipe handles reconnection internally
474+
// We continue the loop to keep reading after reconnection
475+
continue
458476
}
459477
}
460478
}()

0 commit comments

Comments
 (0)