Skip to content

Commit fd99a7f

Browse files
committed
chore: added immortal streams, manager and agent API integration
1 parent 9cafe05 commit fd99a7f

File tree

10 files changed

+3001
-0
lines changed

10 files changed

+3001
-0
lines changed

agent/agent.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ import (
4141
"github.com/coder/coder/v2/agent/agentexec"
4242
"github.com/coder/coder/v2/agent/agentscripts"
4343
"github.com/coder/coder/v2/agent/agentssh"
44+
"github.com/coder/coder/v2/agent/immortalstreams"
4445
"github.com/coder/coder/v2/agent/proto"
4546
"github.com/coder/coder/v2/agent/proto/resourcesmonitor"
4647
"github.com/coder/coder/v2/agent/reconnectingpty"
@@ -280,6 +281,9 @@ type agent struct {
280281
devcontainers bool
281282
containerAPIOptions []agentcontainers.Option
282283
containerAPI *agentcontainers.API
284+
285+
// Immortal streams
286+
immortalStreamsManager *immortalstreams.Manager
283287
}
284288

285289
func (a *agent) TailnetConn() *tailnet.Conn {
@@ -347,6 +351,9 @@ func (a *agent) init() {
347351

348352
a.containerAPI = agentcontainers.NewAPI(a.logger.Named("containers"), containerAPIOpts...)
349353

354+
// Initialize immortal streams manager
355+
a.immortalStreamsManager = immortalstreams.New(a.logger.Named("immortal-streams"), &net.Dialer{})
356+
350357
a.reconnectingPTYServer = reconnectingpty.NewServer(
351358
a.logger.Named("reconnecting-pty"),
352359
a.sshServer,
@@ -1930,6 +1937,12 @@ func (a *agent) Close() error {
19301937
a.logger.Error(a.hardCtx, "container API close", slog.Error(err))
19311938
}
19321939

1940+
if a.immortalStreamsManager != nil {
1941+
if err := a.immortalStreamsManager.Close(); err != nil {
1942+
a.logger.Error(a.hardCtx, "immortal streams manager close", slog.Error(err))
1943+
}
1944+
}
1945+
19331946
// Wait for the graceful shutdown to complete, but don't wait forever so
19341947
// that we don't break user expectations.
19351948
go func() {

agent/api.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"github.com/go-chi/chi/v5"
99
"github.com/google/uuid"
1010

11+
"github.com/coder/coder/v2/agent/immortalstreams"
1112
"github.com/coder/coder/v2/coderd/httpapi"
1213
"github.com/coder/coder/v2/codersdk"
1314
)
@@ -66,6 +67,12 @@ func (a *agent) apiHandler() http.Handler {
6667
r.Get("/debug/manifest", a.HandleHTTPDebugManifest)
6768
r.Get("/debug/prometheus", promHandler.ServeHTTP)
6869

70+
// Mount immortal streams API
71+
if a.immortalStreamsManager != nil {
72+
immortalStreamsHandler := immortalstreams.NewHandler(a.logger, a.immortalStreamsManager)
73+
r.Mount("/api/v0/immortal-stream", immortalStreamsHandler.Routes())
74+
}
75+
6976
return r
7077
}
7178

agent/immortalstreams/handler.go

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
package immortalstreams
2+
3+
import (
4+
"context"
5+
"errors"
6+
"fmt"
7+
"net/http"
8+
"strconv"
9+
"strings"
10+
11+
"github.com/go-chi/chi/v5"
12+
"github.com/google/uuid"
13+
"golang.org/x/xerrors"
14+
15+
"cdr.dev/slog"
16+
"github.com/coder/coder/v2/coderd/httpapi"
17+
"github.com/coder/coder/v2/codersdk"
18+
"github.com/coder/websocket"
19+
)
20+
21+
// Handler handles immortal stream requests
22+
type Handler struct {
23+
logger slog.Logger
24+
manager *Manager
25+
}
26+
27+
// NewHandler creates a new immortal streams handler
28+
func NewHandler(logger slog.Logger, manager *Manager) *Handler {
29+
return &Handler{
30+
logger: logger,
31+
manager: manager,
32+
}
33+
}
34+
35+
// Routes registers the immortal streams routes
36+
func (h *Handler) Routes() chi.Router {
37+
r := chi.NewRouter()
38+
39+
r.Post("/", h.createStream)
40+
r.Get("/", h.listStreams)
41+
r.Route("/{streamID}", func(r chi.Router) {
42+
r.Use(h.streamMiddleware)
43+
r.Get("/", h.handleStreamRequest)
44+
r.Delete("/", h.deleteStream)
45+
})
46+
47+
return r
48+
}
49+
50+
// streamMiddleware validates and extracts the stream ID
51+
func (*Handler) streamMiddleware(next http.Handler) http.Handler {
52+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
53+
streamIDStr := chi.URLParam(r, "streamID")
54+
streamID, err := uuid.Parse(streamIDStr)
55+
if err != nil {
56+
httpapi.Write(r.Context(), w, http.StatusBadRequest, codersdk.Response{
57+
Message: "Invalid stream ID format",
58+
})
59+
return
60+
}
61+
62+
ctx := context.WithValue(r.Context(), streamIDKey{}, streamID)
63+
next.ServeHTTP(w, r.WithContext(ctx))
64+
})
65+
}
66+
67+
// createStream creates a new immortal stream
68+
func (h *Handler) createStream(w http.ResponseWriter, r *http.Request) {
69+
ctx := r.Context()
70+
71+
var req codersdk.CreateImmortalStreamRequest
72+
if !httpapi.Read(ctx, w, r, &req) {
73+
return
74+
}
75+
76+
stream, err := h.manager.CreateStream(ctx, req.TCPPort)
77+
if err != nil {
78+
switch {
79+
case errors.Is(err, ErrTooManyStreams):
80+
httpapi.Write(ctx, w, http.StatusServiceUnavailable, codersdk.Response{
81+
Message: "Too many Immortal Streams.",
82+
})
83+
return
84+
case errors.Is(err, ErrConnRefused):
85+
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
86+
Message: "The connection was refused.",
87+
})
88+
return
89+
default:
90+
httpapi.InternalServerError(w, err)
91+
return
92+
}
93+
}
94+
95+
httpapi.Write(ctx, w, http.StatusCreated, stream)
96+
}
97+
98+
// listStreams lists all immortal streams
99+
func (h *Handler) listStreams(w http.ResponseWriter, r *http.Request) {
100+
ctx := r.Context()
101+
streams := h.manager.ListStreams()
102+
httpapi.Write(ctx, w, http.StatusOK, streams)
103+
}
104+
105+
// handleStreamRequest handles GET requests for a specific stream and returns stream info or handles WebSocket upgrades
106+
func (h *Handler) handleStreamRequest(w http.ResponseWriter, r *http.Request) {
107+
ctx := r.Context()
108+
streamID := getStreamID(ctx)
109+
110+
// Check if this is a WebSocket upgrade request by looking for WebSocket headers
111+
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
112+
h.handleUpgrade(w, r)
113+
return
114+
}
115+
116+
// Otherwise, return stream info
117+
stream, ok := h.manager.GetStream(streamID)
118+
if !ok {
119+
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
120+
Message: "Stream not found",
121+
})
122+
return
123+
}
124+
125+
httpapi.Write(ctx, w, http.StatusOK, stream.ToAPI())
126+
}
127+
128+
// deleteStream deletes a stream
129+
func (h *Handler) deleteStream(w http.ResponseWriter, r *http.Request) {
130+
ctx := r.Context()
131+
streamID := getStreamID(ctx)
132+
133+
err := h.manager.DeleteStream(streamID)
134+
if err != nil {
135+
switch {
136+
case errors.Is(err, ErrStreamNotFound):
137+
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
138+
Message: "Stream not found",
139+
})
140+
return
141+
default:
142+
httpapi.InternalServerError(w, err)
143+
return
144+
}
145+
}
146+
147+
w.WriteHeader(http.StatusNoContent)
148+
}
149+
150+
// handleUpgrade handles WebSocket upgrade for immortal stream connections
151+
func (h *Handler) handleUpgrade(w http.ResponseWriter, r *http.Request) {
152+
ctx := r.Context()
153+
streamID := getStreamID(ctx)
154+
155+
// Get sequence numbers from headers
156+
readSeqNum, err := parseSequenceNumber(r.Header.Get(codersdk.HeaderImmortalStreamSequenceNum))
157+
if err != nil {
158+
httpapi.Write(ctx, w, http.StatusBadRequest, codersdk.Response{
159+
Message: fmt.Sprintf("Invalid sequence number: %v", err),
160+
})
161+
return
162+
}
163+
164+
// Check if stream exists
165+
_, ok := h.manager.GetStream(streamID)
166+
if !ok {
167+
httpapi.Write(ctx, w, http.StatusNotFound, codersdk.Response{
168+
Message: "Stream not found",
169+
})
170+
return
171+
}
172+
173+
// Upgrade to WebSocket
174+
conn, err := websocket.Accept(w, r, &websocket.AcceptOptions{
175+
CompressionMode: websocket.CompressionDisabled,
176+
})
177+
if err != nil {
178+
h.logger.Error(ctx, "failed to accept websocket", slog.Error(err))
179+
return
180+
}
181+
182+
// Create a context that we can cancel to clean up the connection
183+
connCtx, cancel := context.WithCancel(ctx)
184+
defer cancel()
185+
186+
// Ensure WebSocket is closed when this function returns
187+
defer func() {
188+
conn.Close(websocket.StatusNormalClosure, "connection closed")
189+
}()
190+
191+
// Create a WebSocket adapter
192+
wsConn := &wsConn{
193+
conn: conn,
194+
logger: h.logger,
195+
ctx: connCtx,
196+
cancel: cancel,
197+
}
198+
199+
// Handle the reconnection - this establishes the connection
200+
// BackedPipe only needs the reader sequence number for replay
201+
err = h.manager.HandleConnection(streamID, wsConn, readSeqNum)
202+
if err != nil {
203+
switch {
204+
case errors.Is(err, ErrStreamNotFound):
205+
conn.Close(websocket.StatusUnsupportedData, "Stream not found")
206+
return
207+
case errors.Is(err, ErrAlreadyConnected):
208+
conn.Close(websocket.StatusPolicyViolation, "Already connected")
209+
return
210+
default:
211+
h.logger.Error(ctx, "failed to handle connection", slog.Error(err))
212+
conn.Close(websocket.StatusInternalError, err.Error())
213+
return
214+
}
215+
}
216+
217+
// Keep the connection open until the context is canceled
218+
// The wsConn will handle connection closure through its Read/Write methods
219+
// When the connection is closed, the backing pipe will detect it and the context should be canceled
220+
<-connCtx.Done()
221+
}
222+
223+
// wsConn adapts a WebSocket connection to io.ReadWriteCloser
224+
type wsConn struct {
225+
conn *websocket.Conn
226+
logger slog.Logger
227+
ctx context.Context
228+
cancel context.CancelFunc
229+
}
230+
231+
func (c *wsConn) Read(p []byte) (n int, err error) {
232+
typ, data, err := c.conn.Read(c.ctx)
233+
if err != nil {
234+
// Cancel the context when read fails (connection closed)
235+
c.cancel()
236+
return 0, err
237+
}
238+
if typ != websocket.MessageBinary {
239+
return 0, xerrors.Errorf("unexpected message type: %v", typ)
240+
}
241+
n = copy(p, data)
242+
return n, nil
243+
}
244+
245+
func (c *wsConn) Write(p []byte) (n int, err error) {
246+
err = c.conn.Write(c.ctx, websocket.MessageBinary, p)
247+
if err != nil {
248+
// Cancel the context when write fails (connection closed)
249+
c.cancel()
250+
return 0, err
251+
}
252+
return len(p), nil
253+
}
254+
255+
func (c *wsConn) Close() error {
256+
c.cancel() // Cancel the context when explicitly closed
257+
return c.conn.Close(websocket.StatusNormalClosure, "")
258+
}
259+
260+
// parseSequenceNumber parses a sequence number from a string
261+
func parseSequenceNumber(s string) (uint64, error) {
262+
if s == "" {
263+
return 0, nil
264+
}
265+
return strconv.ParseUint(s, 10, 64)
266+
}
267+
268+
// getStreamID gets the stream ID from the context
269+
func getStreamID(ctx context.Context) uuid.UUID {
270+
id, _ := ctx.Value(streamIDKey{}).(uuid.UUID)
271+
return id
272+
}
273+
274+
type streamIDKey struct{}

0 commit comments

Comments
 (0)