Skip to content

Commit 62d31a1

Browse files
committed
WIP
1 parent 1e15ee8 commit 62d31a1

File tree

6 files changed

+213
-105
lines changed

6 files changed

+213
-105
lines changed

agent/immortalstreams/backedpipe/backed_pipe.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,20 @@ func (bp *BackedPipe) reconnectLocked() error {
256256
bp.connGen++
257257
bp.state = connected
258258

259+
// Store the generation number before releasing the lock
260+
currentGen := bp.connGen
261+
262+
// Release the lock before calling SetGeneration to avoid a deadlock.
263+
// SetGeneration acquires its own mutex, and we don't want to hold
264+
// the BackedPipe mutex while waiting for component mutexes.
265+
bp.mu.Unlock()
266+
259267
// Update the generation on reader and writer for error reporting
260-
bp.reader.SetGeneration(bp.connGen)
261-
bp.writer.SetGeneration(bp.connGen)
268+
bp.reader.SetGeneration(currentGen)
269+
bp.writer.SetGeneration(currentGen)
270+
271+
// Re-acquire the lock to maintain the function contract
272+
bp.mu.Lock()
262273

263274
return nil
264275
}

agent/immortalstreams/handler.go

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package immortalstreams
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"net/http"
78
"strconv"
@@ -74,20 +75,21 @@ func (h *Handler) createStream(w http.ResponseWriter, r *http.Request) {
7475

7576
stream, err := h.manager.CreateStream(ctx, req.TCPPort)
7677
if err != nil {
77-
if strings.Contains(err.Error(), "too many immortal streams") {
78+
switch {
79+
case errors.Is(err, ErrTooManyStreams):
7880
httpapi.Write(ctx, w, http.StatusServiceUnavailable, codersdk.Response{
7981
Message: "Too many Immortal Streams.",
8082
})
8183
return
82-
}
83-
if strings.Contains(err.Error(), "the connection was refused") {
84+
case errors.Is(err, ErrConnRefused):
8485
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
8586
Message: "The connection was refused.",
8687
})
8788
return
89+
default:
90+
httpapi.InternalServerError(w, err)
91+
return
8892
}
89-
httpapi.InternalServerError(w, err)
90-
return
9193
}
9294

9395
httpapi.Write(ctx, w, http.StatusCreated, stream)
@@ -130,14 +132,16 @@ func (h *Handler) deleteStream(w http.ResponseWriter, r *http.Request) {
130132

131133
err := h.manager.DeleteStream(streamID)
132134
if err != nil {
133-
if strings.Contains(err.Error(), "stream not found") {
135+
switch {
136+
case errors.Is(err, ErrStreamNotFound):
134137
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
135138
Message: "Stream not found",
136139
})
137140
return
141+
default:
142+
httpapi.InternalServerError(w, err)
143+
return
138144
}
139-
httpapi.InternalServerError(w, err)
140-
return
141145
}
142146

143147
w.WriteHeader(http.StatusNoContent)
@@ -196,9 +200,18 @@ func (h *Handler) handleUpgrade(w http.ResponseWriter, r *http.Request) {
196200
// BackedPipe only needs the reader sequence number for replay
197201
err = h.manager.HandleConnection(streamID, wsConn, readSeqNum)
198202
if err != nil {
199-
h.logger.Error(ctx, "failed to handle connection", slog.Error(err))
200-
conn.Close(websocket.StatusInternalError, err.Error())
201-
return
203+
switch {
204+
case errors.Is(err, ErrStreamNotFound):
205+
conn.Close(websocket.StatusUnsupportedData, "Stream not found")
206+
return
207+
case errors.Is(err, ErrAlreadyConnected):
208+
conn.Close(websocket.StatusPolicyViolation, "Already connected")
209+
return
210+
default:
211+
h.logger.Error(ctx, "failed to handle connection", slog.Error(err))
212+
conn.Close(websocket.StatusInternalError, err.Error())
213+
return
214+
}
202215
}
203216

204217
// Keep the connection open until the context is canceled

agent/immortalstreams/manager.go

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package immortalstreams
22

33
import (
44
"context"
5+
"errors"
56
"fmt"
67
"io"
78
"net"
89
"sync"
10+
"syscall"
911
"time"
1012

1113
"github.com/google/uuid"
@@ -16,6 +18,14 @@ import (
1618
"github.com/coder/coder/v2/codersdk"
1719
)
1820

21+
// Sentinel errors returned by this package; callers should match them with errors.Is.
22+
var (
23+
ErrTooManyStreams = xerrors.New("too many streams")
24+
ErrStreamNotFound = xerrors.New("stream not found")
25+
ErrConnRefused = xerrors.New("connection refused")
26+
ErrAlreadyConnected = xerrors.New("already connected")
27+
)
28+
1929
const (
2030
// MaxStreams is the maximum number of immortal streams allowed per agent
2131
MaxStreams = 32
@@ -56,7 +66,7 @@ func (m *Manager) CreateStream(ctx context.Context, port int) (*codersdk.Immorta
5666
// Try to evict a disconnected stream
5767
evicted := m.evictOldestDisconnectedLocked()
5868
if !evicted {
59-
return nil, xerrors.New("too many immortal streams")
69+
return nil, ErrTooManyStreams
6070
}
6171
}
6272

@@ -65,7 +75,7 @@ func (m *Manager) CreateStream(ctx context.Context, port int) (*codersdk.Immorta
6575
conn, err := m.dialer.DialContext(ctx, "tcp", addr)
6676
if err != nil {
6777
if isConnectionRefused(err) {
68-
return nil, xerrors.Errorf("the connection was refused")
78+
return nil, ErrConnRefused
6979
}
7080
return nil, xerrors.Errorf("dial local service: %w", err)
7181
}
@@ -88,13 +98,9 @@ func (m *Manager) CreateStream(ctx context.Context, port int) (*codersdk.Immorta
8898

8999
m.streams[id] = stream
90100

91-
return &codersdk.ImmortalStream{
92-
ID: id,
93-
Name: name,
94-
TCPPort: port,
95-
CreatedAt: stream.createdAt,
96-
LastConnectionAt: stream.createdAt,
97-
}, nil
101+
// Return the API representation of the stream
102+
apiStream := stream.ToAPI()
103+
return &apiStream, nil
98104
}
99105

100106
// GetStream returns a stream by ID
@@ -124,7 +130,7 @@ func (m *Manager) DeleteStream(id uuid.UUID) error {
124130

125131
stream, ok := m.streams[id]
126132
if !ok {
127-
return xerrors.New("stream not found")
133+
return ErrStreamNotFound
128134
}
129135

130136
if err := stream.Close(); err != nil {
@@ -216,17 +222,28 @@ func (m *Manager) HandleConnection(id uuid.UUID, conn io.ReadWriteCloser, readSe
216222
m.mu.RUnlock()
217223

218224
if !ok {
219-
return xerrors.New("stream not found")
225+
return ErrStreamNotFound
220226
}
221227

222228
return stream.HandleReconnect(conn, readSeqNum)
223229
}
224230

225231
// isConnectionRefused reports whether err, or any error it wraps, is
// syscall.ECONNREFUSED.
//
// errors.Is walks the entire wrap chain, including through *net.OpError and
// *os.SyscallError (both implement Unwrap), so a dial failure reported as a
// "dial" *net.OpError is matched without inspecting the OpError separately.
// A nil err yields false.
//
// NOTE(review): only the syscall.ECONNREFUSED value is matched; confirm that
// the dialer surfaces that errno on every platform this package targets.
func isConnectionRefused(err error) bool {
	return errors.Is(err, syscall.ECONNREFUSED)
}

agent/immortalstreams/manager_test.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package immortalstreams_test
22

33
import (
44
"context"
5+
"errors"
56
"io"
67
"net"
78
"runtime"
@@ -73,7 +74,7 @@ func TestManager_CreateStream(t *testing.T) {
7374
// Use a port that's not listening
7475
_, err := manager.CreateStream(ctx, 65535)
7576
require.Error(t, err)
76-
require.Contains(t, err.Error(), "connection was refused")
77+
require.True(t, errors.Is(err, immortalstreams.ErrConnRefused))
7778
})
7879

7980
t.Run("MaxStreamsLimit", func(t *testing.T) {
@@ -145,7 +146,7 @@ func TestManager_CreateStream(t *testing.T) {
145146
// All streams should be connected, so creating another should fail
146147
_, err = manager.CreateStream(ctx, port)
147148
require.Error(t, err)
148-
require.Contains(t, err.Error(), "too many immortal streams")
149+
require.True(t, errors.Is(err, immortalstreams.ErrTooManyStreams))
149150

150151
// Disconnect one stream
151152
err = manager.DeleteStream(streams[0])
@@ -259,7 +260,7 @@ func TestManager_DeleteStream(t *testing.T) {
259260
// Deleting again should error
260261
err = manager.DeleteStream(stream.ID)
261262
require.Error(t, err)
262-
require.Contains(t, err.Error(), "stream not found")
263+
require.True(t, errors.Is(err, immortalstreams.ErrStreamNotFound))
263264
}
264265

265266
func TestManager_GetStream(t *testing.T) {

agent/immortalstreams/stream.go

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,13 @@ func (r *streamReconnector) Reconnect(ctx context.Context, writerSeqNum uint64)
110110
r.s.reconnectCond.Broadcast()
111111
r.s.mu.Unlock()
112112

113-
// Wait for response from HandleReconnect or context cancellation
113+
// Wait for a response from HandleReconnect, context cancellation, or the timeout below
114114
r.s.logger.Info(context.Background(), "reconnect function waiting for response")
115+
116+
// Add a timeout to prevent indefinite hanging
117+
timeout := time.NewTimer(30 * time.Second)
118+
defer timeout.Stop()
119+
115120
select {
116121
case resp := <-responseChan:
117122
r.s.logger.Info(context.Background(), "reconnect function got response",
@@ -129,6 +134,16 @@ func (r *streamReconnector) Reconnect(ctx context.Context, writerSeqNum uint64)
129134
// The stream's Close() method will handle cleanup
130135
r.s.logger.Info(context.Background(), "reconnect function shutdown signal received")
131136
return nil, 0, xerrors.New("stream is shutting down")
137+
case <-timeout.C:
138+
// Timeout occurred - clean up the pending request
139+
r.s.mu.Lock()
140+
if r.s.pendingReconnect != nil {
141+
r.s.pendingReconnect = nil
142+
r.s.handshakePending = false
143+
}
144+
r.s.mu.Unlock()
145+
r.s.logger.Info(context.Background(), "reconnect function timed out")
146+
return nil, 0, xerrors.New("timeout waiting for reconnection response")
132147
}
133148
}
134149

@@ -276,7 +291,7 @@ func (s *Stream) HandleReconnect(clientConn io.ReadWriteCloser, readSeqNum uint6
276291
if s.connected {
277292
s.mu.Unlock()
278293
s.logger.Debug(context.Background(), "another goroutine completed reconnection")
279-
return xerrors.New("stream is already connected")
294+
return ErrAlreadyConnected
280295
}
281296

282297
// Ensure a reconnect attempt is requested while we wait.

0 commit comments

Comments
 (0)