From e10a5cffbc39f7fc33e3dacbc795c94173f9a19c Mon Sep 17 00:00:00 2001 From: Ethan Dickson Date: Fri, 27 Jun 2025 06:50:44 +0000 Subject: [PATCH] chore: populate connectionlog count using a separate query --- coderd/database/dbauthz/dbauthz.go | 19 ++- coderd/database/dbauthz/dbauthz_test.go | 36 +++++ coderd/database/dbauthz/setup_test.go | 2 +- coderd/database/dbmetrics/querymetrics.go | 14 ++ coderd/database/dbmock/dbmock.go | 30 ++++ coderd/database/modelqueries.go | 48 ++++++ coderd/database/modelqueries_internal_test.go | 13 ++ coderd/database/querier.go | 1 + coderd/database/querier_test.go | 95 ++++++++++++ coderd/database/queries.sql.go | 144 ++++++++++++++++++ coderd/database/queries/connectionlogs.sql | 106 +++++++++++++ coderd/searchquery/search.go | 23 ++- coderd/searchquery/search_test.go | 4 +- enterprise/coderd/connectionlog.go | 22 ++- enterprise/coderd/connectionlog_test.go | 6 +- 15 files changed, 552 insertions(+), 11 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a1c758ce03415..8b616f34b8441 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1353,15 +1353,26 @@ func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLog if err == nil { return q.db.CountAuditLogs(ctx, arg) } - prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAuditLog.Type) if err != nil { return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err) } - return q.db.CountAuthorizedAuditLogs(ctx, arg, prep) } +func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) { + // Just like the actual query, shortcut if the user is an owner. + err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog) + if err == nil { + return q.db.CountConnectionLogs(ctx, arg) + } + prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceConnectionLog.Type) + if err != nil { + return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.CountAuthorizedConnectionLogs(ctx, arg, prep) +} + func (q *querier) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceWorkspace.All()); err != nil { return nil, err @@ -5392,3 +5403,7 @@ func (q *querier) CountAuthorizedAuditLogs(ctx context.Context, arg database.Cou func (q *querier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg database.GetConnectionLogsOffsetParams, _ rbac.PreparedAuthorized) ([]database.GetConnectionLogsOffsetRow, error) { return q.GetConnectionLogsOffset(ctx, arg) } + +func (q *querier) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, _ rbac.PreparedAuthorized) (int64, error) { + return q.CountConnectionLogs(ctx, arg) +} diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 5416f33e521ec..2ea27f7d92342 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -406,6 +406,42 @@ func (s *MethodTestSuite) TestConnectionLogs() { LimitOpt: 10, }, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead) })) + s.Run("CountConnectionLogs", s.Subtest(func(db database.Store, check *expects) { + ws := createWorkspace(s.T(), db) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + check.Args(database.CountConnectionLogsParams{}).Asserts( + rbac.ResourceConnectionLog, policy.ActionRead, + ).WithNotAuthorized("nil") + })) + s.Run("CountAuthorizedConnectionLogs", s.Subtest(func(db database.Store, check *expects) { + ws := createWorkspace(s.T(), db) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + _ = dbgen.ConnectionLog(s.T(), db, database.UpsertConnectionLogParams{ + Type: database.ConnectionTypeSsh, + WorkspaceID: ws.ID, + OrganizationID: ws.OrganizationID, + WorkspaceOwnerID: ws.OwnerID, + }) + check.Args(database.CountConnectionLogsParams{}, emptyPreparedAuthorized{}).Asserts( + rbac.ResourceConnectionLog, policy.ActionRead, + ) + })) } func (s *MethodTestSuite) TestFile() { diff --git a/coderd/database/dbauthz/setup_test.go b/coderd/database/dbauthz/setup_test.go index 23effafc632e0..d4dacb78a4d50 100644 --- a/coderd/database/dbauthz/setup_test.go +++ b/coderd/database/dbauthz/setup_test.go @@ -318,7 +318,7 @@ func hasEmptyResponse(values []reflect.Value) bool { } } - // Special case for int64, as it's the return type for count query. + // Special case for int64, as it's the return type for count queries. if r.Kind() == reflect.Int64 { if r.Int() == 0 { return true diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index e353a4688281d..a0090a1103279 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -194,6 +194,13 @@ func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.Coun return r0, r1 } +func (m queryMetricsStore) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountConnectionLogs(ctx, arg) + m.queryLatencies.WithLabelValues("CountConnectionLogs").Observe(time.Since(start).Seconds()) + return r0, r1 +} + func (m queryMetricsStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) { start := time.Now() r0, r1 := m.s.CountInProgressPrebuilds(ctx) @@ -3413,3 +3420,10 @@ func (m queryMetricsStore) GetAuthorizedConnectionLogsOffset(ctx context.Context m.queryLatencies.WithLabelValues("GetAuthorizedConnectionLogsOffset").Observe(time.Since(start).Seconds()) return r0, r1 } + +func (m queryMetricsStore) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountAuthorizedConnectionLogs(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("CountAuthorizedConnectionLogs").Observe(time.Since(start).Seconds()) + return r0, r1 +} diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 14e5344325b9b..723c4f3687e81 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -278,6 +278,36 @@ func (mr *MockStoreMockRecorder) CountAuthorizedAuditLogs(ctx, arg, prepared any return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAuditLogs", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAuditLogs), ctx, arg, prepared) } +// CountAuthorizedConnectionLogs mocks base method. +func (m *MockStore) CountAuthorizedConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAuthorizedConnectionLogs", ctx, arg, prepared) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAuthorizedConnectionLogs indicates an expected call of CountAuthorizedConnectionLogs. +func (mr *MockStoreMockRecorder) CountAuthorizedConnectionLogs(ctx, arg, prepared any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountAuthorizedConnectionLogs), ctx, arg, prepared) +} + +// CountConnectionLogs mocks base method. +func (m *MockStore) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountConnectionLogs", ctx, arg) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountConnectionLogs indicates an expected call of CountConnectionLogs. +func (mr *MockStoreMockRecorder) CountConnectionLogs(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountConnectionLogs", reflect.TypeOf((*MockStore)(nil).CountConnectionLogs), ctx, arg) +} + // CountInProgressPrebuilds mocks base method. func (m *MockStore) CountInProgressPrebuilds(ctx context.Context) ([]database.CountInProgressPrebuildsRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 193ac3daa46bf..6bb7483847a2e 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -614,6 +614,7 @@ func (q *sqlQuerier) CountAuthorizedAuditLogs(ctx context.Context, arg CountAudi type connectionLogQuerier interface { GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) + CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) } func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg GetConnectionLogsOffsetParams, prepared rbac.PreparedAuthorized) ([]GetConnectionLogsOffsetRow, error) { @@ -700,6 +701,53 @@ func (q *sqlQuerier) GetAuthorizedConnectionLogsOffset(ctx context.Context, arg return items, nil } +func (q *sqlQuerier) CountAuthorizedConnectionLogs(ctx context.Context, arg CountConnectionLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.ConnectionLogConverter(), + }) + if err != nil { + return 0, xerrors.Errorf("compile authorized filter: %w", err) + } + filtered, err := insertAuthorizedFilter(countConnectionLogs, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return 0, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: CountAuthorizedConnectionLogs :one\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.OrganizationID, + arg.WorkspaceOwner, + arg.WorkspaceOwnerID, + arg.WorkspaceOwnerEmail, + arg.Type, + arg.UserID, + arg.Username, + arg.UserEmail, + arg.ConnectedAfter, + arg.ConnectedBefore, + arg.WorkspaceID, + arg.ConnectionID, + arg.Status, + ) + if err != nil { + return 0, err + } + defer rows.Close() + var count int64 + for rows.Next() { + if err := rows.Scan(&count); err != nil { + return 0, err + } + } + if err := rows.Close(); err != nil { + return 0, err + } + if err := rows.Err(); err != nil { + return 0, err + } + return count, nil +} + func insertAuthorizedFilter(query string, replaceWith string) (string, error) { if !strings.Contains(query, authorizedQueryPlaceholder) { return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query") diff --git a/coderd/database/modelqueries_internal_test.go b/coderd/database/modelqueries_internal_test.go index 4f675a1b60785..275ed947a3e4c 100644 --- a/coderd/database/modelqueries_internal_test.go +++ b/coderd/database/modelqueries_internal_test.go @@ -76,6 +76,19 @@ func TestAuditLogsQueryConsistency(t *testing.T) { } } +// Same as TestAuditLogsQueryConsistency, but for connection logs. +func TestConnectionLogsQueryConsistency(t *testing.T) { + t.Parallel() + + getWhereClause := extractWhereClause(getConnectionLogsOffset) + require.NotEmpty(t, getWhereClause, "getConnectionLogsOffset query should have a WHERE clause") + + countWhereClause := extractWhereClause(countConnectionLogs) + require.NotEmpty(t, countWhereClause, "countConnectionLogs query should have a WHERE clause") + + require.Equal(t, getWhereClause, countWhereClause, "getConnectionLogsOffset and countConnectionLogs queries should have the same WHERE clause") +} + // extractWhereClause extracts the WHERE clause from a SQL query string func extractWhereClause(query string) string { // Find WHERE and get everything after it diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 8af37596cb5c6..72f511618838b 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -66,6 +66,7 @@ type sqlcQuerier interface { CleanTailnetLostPeers(ctx context.Context) error CleanTailnetTunnels(ctx context.Context) error CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) + CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) // CountInProgressPrebuilds returns the number of in-progress prebuilds, grouped by preset ID and transition. // Prebuild considered in-progress if it's in the "starting", "stopping", or "deleting" state. CountInProgressPrebuilds(ctx context.Context) ([]CountInProgressPrebuildsRow, error) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index a3d48e46b4fe7..20b07450364af 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -2168,6 +2168,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: No logs returned require.Len(t, logs, 0, "no logs should be returned") + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(memberCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) t.Run("SiteWideAuditor", func(t *testing.T) { @@ -2186,6 +2190,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: All logs are returned require.ElementsMatch(t, connectionOnlyIDs(allLogs), connectionOnlyIDs(logs)) + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(siteAuditorCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) t.Run("SingleOrgAuditor", func(t *testing.T) { @@ -2205,6 +2213,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: Only the logs for the organization are returned require.ElementsMatch(t, orgConnectionLogs[orgID], connectionOnlyIDs(logs)) + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(orgAuditCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) t.Run("TwoOrgAuditors", func(t *testing.T) { @@ -2225,6 +2237,10 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: All logs for both organizations are returned require.ElementsMatch(t, append(orgConnectionLogs[first], orgConnectionLogs[second]...), connectionOnlyIDs(logs)) + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(multiOrgAuditCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) t.Run("ErroneousOrg", func(t *testing.T) { @@ -2243,9 +2259,71 @@ func TestGetAuthorizedConnectionLogsOffset(t *testing.T) { require.NoError(t, err) // Then: No logs are returned require.Len(t, logs, 0, "no logs should be returned") + // And: The count matches the number of logs returned + count, err := authDb.CountConnectionLogs(userCtx, database.CountConnectionLogsParams{}) + require.NoError(t, err) + require.EqualValues(t, len(logs), count) }) } +func TestCountConnectionLogs(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + + db, _ := dbtestutil.NewDB(t) + + orgA := dbfake.Organization(t, db).Do() + userA := dbgen.User(t, db, database.User{}) + tplA := dbgen.Template(t, db, database.Template{OrganizationID: orgA.Org.ID, CreatedBy: userA.ID}) + wsA := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userA.ID, OrganizationID: orgA.Org.ID, TemplateID: tplA.ID}) + + orgB := dbfake.Organization(t, db).Do() + userB := dbgen.User(t, db, database.User{}) + tplB := dbgen.Template(t, db, database.Template{OrganizationID: orgB.Org.ID, CreatedBy: userB.ID}) + wsB := dbgen.Workspace(t, db, database.WorkspaceTable{OwnerID: userB.ID, OrganizationID: orgB.Org.ID, TemplateID: tplB.ID}) + + // Create logs for two different orgs. + for i := 0; i < 20; i++ { + dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{ + OrganizationID: wsA.OrganizationID, + WorkspaceOwnerID: wsA.OwnerID, + WorkspaceID: wsA.ID, + Type: database.ConnectionTypeSsh, + }) + } + for i := 0; i < 10; i++ { + dbgen.ConnectionLog(t, db, database.UpsertConnectionLogParams{ + OrganizationID: wsB.OrganizationID, + WorkspaceOwnerID: wsB.OwnerID, + WorkspaceID: wsB.ID, + Type: database.ConnectionTypeSsh, + }) + } + + // Count with a filter for orgA. + countParams := database.CountConnectionLogsParams{ + OrganizationID: orgA.Org.ID, + } + totalCount, err := db.CountConnectionLogs(ctx, countParams) + require.NoError(t, err) + require.Equal(t, int64(20), totalCount) + + // Get a paginated result for the same filter. + getParams := database.GetConnectionLogsOffsetParams{ + OrganizationID: orgA.Org.ID, + LimitOpt: 5, + OffsetOpt: 10, + } + logs, err := db.GetConnectionLogsOffset(ctx, getParams) + require.NoError(t, err) + require.Len(t, logs, 5) + + // The count with the filter should remain the same, independent of pagination. + countAfterGet, err := db.CountConnectionLogs(ctx, countParams) + require.NoError(t, err) + require.Equal(t, int64(20), countAfterGet) +} + func TestConnectionLogsOffsetFilters(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) @@ -2484,7 +2562,24 @@ func TestConnectionLogsOffsetFilters(t *testing.T) { t.Parallel() logs, err := db.GetConnectionLogsOffset(ctx, tc.params) require.NoError(t, err) + count, err := db.CountConnectionLogs(ctx, database.CountConnectionLogsParams{ + OrganizationID: tc.params.OrganizationID, + WorkspaceOwner: tc.params.WorkspaceOwner, + Type: tc.params.Type, + UserID: tc.params.UserID, + Username: tc.params.Username, + UserEmail: tc.params.UserEmail, + ConnectedAfter: tc.params.ConnectedAfter, + ConnectedBefore: tc.params.ConnectedBefore, + WorkspaceID: tc.params.WorkspaceID, + ConnectionID: tc.params.ConnectionID, + Status: tc.params.Status, + WorkspaceOwnerID: tc.params.WorkspaceOwnerID, + WorkspaceOwnerEmail: tc.params.WorkspaceOwnerEmail, + }) + require.NoError(t, err) require.ElementsMatch(t, tc.expectedLogIDs, connectionOnlyIDs(logs)) + require.Equal(t, len(tc.expectedLogIDs), int(count), "CountConnectionLogs should match the number of returned logs (no offset or limit)") }) } } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index cef983eb0f1b9..676ce75621ded 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -880,6 +880,150 @@ func (q *sqlQuerier) InsertAuditLog(ctx context.Context, arg InsertAuditLogParam return i, err } +const countConnectionLogs = `-- name: CountConnectionLogs :one +SELECT + COUNT(*) AS count +FROM + connection_logs +JOIN users AS workspace_owner ON + connection_logs.workspace_owner_id = workspace_owner.id +LEFT JOIN users ON + connection_logs.user_id = users.id +JOIN organizations ON + connection_logs.organization_id = organizations.id +WHERE + -- Filter organization_id + CASE + WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.organization_id = $1 + ELSE true + END + -- Filter by workspace owner username + AND CASE + WHEN $2 :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE lower(username) = lower($2) AND deleted = false + ) + ELSE true + END + -- Filter by workspace_owner_id + AND CASE + WHEN $3 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + workspace_owner_id = $3 + ELSE true + END + -- Filter by workspace_owner_email + AND CASE + WHEN $4 :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE email = $4 AND deleted = false + ) + ELSE true + END + -- Filter by type + AND CASE + WHEN $5 :: text != '' THEN + type = $5 :: connection_type + ELSE true + END + -- Filter by user_id + AND CASE + WHEN $6 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = $6 + ELSE true + END + -- Filter by username + AND CASE + WHEN $7 :: text != '' THEN + user_id = ( + SELECT id FROM users + WHERE lower(username) = lower($7) AND deleted = false + ) + ELSE true + END + -- Filter by user_email + AND CASE + WHEN $8 :: text != '' THEN + users.email = $8 + ELSE true + END + -- Filter by connected_after + AND CASE + WHEN $9 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time >= $9 + ELSE true + END + -- Filter by connected_before + AND CASE + WHEN $10 :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time <= $10 + ELSE true + END + -- Filter by workspace_id + AND CASE + WHEN $11 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.workspace_id = $11 + ELSE true + END + -- Filter by connection_id + AND CASE + WHEN $12 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.connection_id = $12 + ELSE true + END + -- Filter by whether the session has a disconnect_time + AND CASE + WHEN $13 :: text != '' THEN + (($13 = 'ongoing' AND disconnect_time IS NULL) OR + ($13 = 'completed' AND disconnect_time IS NOT NULL)) AND + -- Exclude web events, since we don't know their close time. + "type" NOT IN ('workspace_app', 'port_forwarding') + ELSE true + END + -- Authorize Filter clause will be injected below in + -- CountAuthorizedConnectionLogs + -- @authorize_filter +` + +type CountConnectionLogsParams struct { + OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` + WorkspaceOwner string `db:"workspace_owner" json:"workspace_owner"` + WorkspaceOwnerID uuid.UUID `db:"workspace_owner_id" json:"workspace_owner_id"` + WorkspaceOwnerEmail string `db:"workspace_owner_email" json:"workspace_owner_email"` + Type string `db:"type" json:"type"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + Username string `db:"username" json:"username"` + UserEmail string `db:"user_email" json:"user_email"` + ConnectedAfter time.Time `db:"connected_after" json:"connected_after"` + ConnectedBefore time.Time `db:"connected_before" json:"connected_before"` + WorkspaceID uuid.UUID `db:"workspace_id" json:"workspace_id"` + ConnectionID uuid.UUID `db:"connection_id" json:"connection_id"` + Status string `db:"status" json:"status"` +} + +func (q *sqlQuerier) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countConnectionLogs, + arg.OrganizationID, + arg.WorkspaceOwner, + arg.WorkspaceOwnerID, + arg.WorkspaceOwnerEmail, + arg.Type, + arg.UserID, + arg.Username, + arg.UserEmail, + arg.ConnectedAfter, + arg.ConnectedBefore, + arg.WorkspaceID, + arg.ConnectionID, + arg.Status, + ) + var count int64 + err := row.Scan(&count) + return count, err +} + const getConnectionLogsOffset = `-- name: GetConnectionLogsOffset :many SELECT connection_logs.id, connection_logs.connect_time, connection_logs.organization_id, connection_logs.workspace_owner_id, connection_logs.workspace_id, connection_logs.workspace_name, connection_logs.agent_name, connection_logs.type, connection_logs.ip, connection_logs.code, connection_logs.user_agent, connection_logs.user_id, connection_logs.slug_or_port, connection_logs.connection_id, connection_logs.disconnect_time, connection_logs.disconnect_reason, diff --git a/coderd/database/queries/connectionlogs.sql b/coderd/database/queries/connectionlogs.sql index e3f231a6b738e..eb2d1b0cb171a 100644 --- a/coderd/database/queries/connectionlogs.sql +++ b/coderd/database/queries/connectionlogs.sql @@ -132,6 +132,112 @@ LIMIT OFFSET @offset_opt; +-- name: CountConnectionLogs :one +SELECT + COUNT(*) AS count +FROM + connection_logs +JOIN users AS workspace_owner ON + connection_logs.workspace_owner_id = workspace_owner.id +LEFT JOIN users ON + connection_logs.user_id = users.id +JOIN organizations ON + connection_logs.organization_id = organizations.id +WHERE + -- Filter organization_id + CASE + WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.organization_id = @organization_id + ELSE true + END + -- Filter by workspace owner username + AND CASE + WHEN @workspace_owner :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE lower(username) = lower(@workspace_owner) AND deleted = false + ) + ELSE true + END + -- Filter by workspace_owner_id + AND CASE + WHEN @workspace_owner_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + workspace_owner_id = @workspace_owner_id + ELSE true + END + -- Filter by workspace_owner_email + AND CASE + WHEN @workspace_owner_email :: text != '' THEN + workspace_owner_id = ( + SELECT id FROM users + WHERE email = @workspace_owner_email AND deleted = false + ) + ELSE true + END + -- Filter by type + AND CASE + WHEN @type :: text != '' THEN + type = @type :: connection_type + ELSE true + END + -- Filter by user_id + AND CASE + WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + user_id = @user_id + ELSE true + END + -- Filter by username + AND CASE + WHEN @username :: text != '' THEN + user_id = ( + SELECT id FROM users + WHERE lower(username) = lower(@username) AND deleted = false + ) + ELSE true + END + -- Filter by user_email + AND CASE + WHEN @user_email :: text != '' THEN + users.email = @user_email + ELSE true + END + -- Filter by connected_after + AND CASE + WHEN @connected_after :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time >= @connected_after + ELSE true + END + -- Filter by connected_before + AND CASE + WHEN @connected_before :: timestamp with time zone != '0001-01-01 00:00:00Z' THEN + connect_time <= @connected_before + ELSE true + END + -- Filter by workspace_id + AND CASE + WHEN @workspace_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.workspace_id = @workspace_id + ELSE true + END + -- Filter by connection_id + AND CASE + WHEN @connection_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN + connection_logs.connection_id = @connection_id + ELSE true + END + -- Filter by whether the session has a disconnect_time + AND CASE + WHEN @status :: text != '' THEN + ((@status = 'ongoing' AND disconnect_time IS NULL) OR + (@status = 'completed' AND disconnect_time IS NOT NULL)) AND + -- Exclude web events, since we don't know their close time. + "type" NOT IN ('workspace_app', 'port_forwarding') + ELSE true + END + -- Authorize Filter clause will be injected below in + -- CountAuthorizedConnectionLogs + -- @authorize_filter +; -- name: UpsertConnectionLog :one INSERT INTO connection_logs ( diff --git a/coderd/searchquery/search.go b/coderd/searchquery/search.go index c17b3db77bdc5..d35f3c94b5ff7 100644 --- a/coderd/searchquery/search.go +++ b/coderd/searchquery/search.go @@ -86,7 +86,7 @@ func AuditLogs(ctx context.Context, db database.Store, query string) (database.G return filter, countFilter, parser.Errors } -func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey database.APIKey) (database.GetConnectionLogsOffsetParams, []codersdk.ValidationError) { +func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey database.APIKey) (database.GetConnectionLogsOffsetParams, database.CountConnectionLogsParams, []codersdk.ValidationError) { // Always lowercase for all searches. query = strings.ToLower(query) values, errors := searchTerms(query, func(term string, values url.Values) error { @@ -94,7 +94,8 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey return nil }) if len(errors) > 0 { - return database.GetConnectionLogsOffsetParams{}, errors + // nolint:exhaustruct // We don't need to initialize these structs because we return an error. + return database.GetConnectionLogsOffsetParams{}, database.CountConnectionLogsParams{}, errors } parser := httpapi.NewQueryParamParser() @@ -122,8 +123,24 @@ func ConnectionLogs(ctx context.Context, db database.Store, query string, apiKey filter.WorkspaceOwner = "" } + // This MUST be kept in sync with the above + countFilter := database.CountConnectionLogsParams{ + OrganizationID: filter.OrganizationID, + WorkspaceOwner: filter.WorkspaceOwner, + WorkspaceOwnerID: filter.WorkspaceOwnerID, + WorkspaceOwnerEmail: filter.WorkspaceOwnerEmail, + Type: filter.Type, + UserID: filter.UserID, + Username: filter.Username, + UserEmail: filter.UserEmail, + ConnectedAfter: filter.ConnectedAfter, + ConnectedBefore: filter.ConnectedBefore, + WorkspaceID: filter.WorkspaceID, + ConnectionID: filter.ConnectionID, + Status: filter.Status, + } parser.ErrorExcessParams(values) - return filter, parser.Errors + return filter, countFilter, parser.Errors } func Users(query string) (database.GetUsersParams, []codersdk.ValidationError) { diff --git a/coderd/searchquery/search_test.go b/coderd/searchquery/search_test.go index c251a4cd5bd90..4744b57edff4a 100644 --- a/coderd/searchquery/search_test.go +++ b/coderd/searchquery/search_test.go @@ -435,7 +435,7 @@ func TestSearchConnectionLogs(t *testing.T) { `connected_before:"2023-01-16T12:00:00+12:00" workspace_id:%s connection_id:%s status:ongoing`, workspaceID.String(), connectionID.String()) - values, errs := searchquery.ConnectionLogs(context.Background(), db, query, database.APIKey{}) + values, _, errs := searchquery.ConnectionLogs(context.Background(), db, query, database.APIKey{}) require.Len(t, errs, 0) expected := database.GetConnectionLogsOffsetParams{ @@ -462,7 +462,7 @@ func TestSearchConnectionLogs(t *testing.T) { db, _ := dbtestutil.NewDB(t) query := `username:me workspace_owner:me` - values, errs := searchquery.ConnectionLogs(context.Background(), db, query, database.APIKey{UserID: userID}) + values, _, errs := searchquery.ConnectionLogs(context.Background(), db, query, database.APIKey{UserID: userID}) require.Len(t, errs, 0) expected := database.GetConnectionLogsOffsetParams{ diff --git a/enterprise/coderd/connectionlog.go b/enterprise/coderd/connectionlog.go index 75413b82708fb..21f0420f0652d 100644 --- a/enterprise/coderd/connectionlog.go +++ b/enterprise/coderd/connectionlog.go @@ -36,7 +36,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) { } queryStr := r.URL.Query().Get("q") - filter, errs := searchquery.ConnectionLogs(ctx, api.Database, queryStr, apiKey) + filter, countFilter, errs := searchquery.ConnectionLogs(ctx, api.Database, queryStr, apiKey) if len(errs) > 0 { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ Message: "Invalid connection search query.", @@ -49,6 +49,24 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) { // #nosec G115 - Safe conversion as pagination limit is expected to be within int32 range filter.LimitOpt = int32(page.Limit) + count, err := api.Database.CountConnectionLogs(ctx, countFilter) + if dbauthz.IsNotAuthorizedError(err) { + httpapi.Forbidden(rw) + return + } + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + + if count == 0 { + httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{ + ConnectionLogs: []codersdk.ConnectionLog{}, + Count: 0, + }) + return + } + dblogs, err := api.Database.GetConnectionLogsOffset(ctx, filter) if dbauthz.IsNotAuthorizedError(err) { httpapi.Forbidden(rw) @@ -61,7 +79,7 @@ func (api *API) connectionLogs(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, codersdk.ConnectionLogResponse{ ConnectionLogs: convertConnectionLogs(dblogs), - Count: 0, // TODO(ethanndickson): Set count + Count: count, }) } diff --git a/enterprise/coderd/connectionlog_test.go b/enterprise/coderd/connectionlog_test.go index b94b2449f37c4..59ff1b780e7b6 100644 --- a/enterprise/coderd/connectionlog_test.go +++ b/enterprise/coderd/connectionlog_test.go @@ -65,6 +65,7 @@ func TestConnectionLogs(t *testing.T) { require.NoError(t, err) require.Len(t, logs.ConnectionLogs, 1) + require.EqualValues(t, 1, logs.Count) require.Equal(t, codersdk.ConnectionTypeSSH, logs.ConnectionLogs[0].Type) }) @@ -84,7 +85,7 @@ func TestConnectionLogs(t *testing.T) { logs, err := client.ConnectionLogs(ctx, codersdk.ConnectionLogsRequest{}) require.NoError(t, err) - + require.EqualValues(t, 0, logs.Count) require.Len(t, logs.ConnectionLogs, 0) }) @@ -133,6 +134,7 @@ func TestConnectionLogs(t *testing.T) { require.NoError(t, err) require.Len(t, logs.ConnectionLogs, 1) + require.EqualValues(t, 1, logs.Count) require.Equal(t, ws.OrganizationID, logs.ConnectionLogs[0].Organization.ID) }) @@ -169,6 +171,7 @@ func TestConnectionLogs(t *testing.T) { require.NoError(t, err) require.Len(t, logs.ConnectionLogs, 1) + require.EqualValues(t, 1, logs.Count) require.NotNil(t, logs.ConnectionLogs[0].WebInfo) require.Equal(t, clog.SlugOrPort.String, logs.ConnectionLogs[0].WebInfo.SlugOrPort) require.Equal(t, clog.UserAgent.String, logs.ConnectionLogs[0].WebInfo.UserAgent) @@ -241,6 +244,7 @@ func TestConnectionLogs(t *testing.T) { require.NoError(t, err) require.Len(t, logs.ConnectionLogs, 1) + require.EqualValues(t, 1, logs.Count) require.NotNil(t, logs.ConnectionLogs[0].SSHInfo) require.Nil(t, logs.ConnectionLogs[0].WebInfo) require.Equal(t, codersdk.ConnectionTypeSSH, logs.ConnectionLogs[0].Type)