mapper: close stale map channels after send timeouts

When the batcher timed out sending to a node, it removed the channel from multiChannelNodeConn but left the old serveLongPoll goroutine reading from that channel. The result was a live stale session: it no longer received new updates, yet it could keep the stream open and block shutdown.

Close the channel when stale-send cleanup prunes it, so that the old map session drains any buffered update and then exits.
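For context, the lifecycle the fix relies on can be shown as a small standalone sketch (illustrative names only, not headscale's actual types): a for-range receive loop drains whatever is already buffered and returns once the sender closes the channel, and a CompareAndSwap flag keeps racing cleanup paths from closing the same channel twice.

package main

import (
	"fmt"
	"sync/atomic"
)

type mapResponse struct{ seq int }

// session stands in for one long-poll map session: a buffered update
// channel plus a closed flag (hypothetical names, for illustration).
type session struct {
	ch     chan *mapResponse
	closed atomic.Bool
}

// closeOnce mirrors the close-at-most-once pattern from the fix: the
// CompareAndSwap guarantees exactly one of the racing cleanup paths
// (send timeout, explicit removal, shutdown) closes the channel.
func (s *session) closeOnce() {
	if s.closed.CompareAndSwap(false, true) {
		close(s.ch)
	}
}

func main() {
	s := &session{ch: make(chan *mapResponse, 1)}

	done := make(chan struct{})
	go func() {
		// A serveLongPoll-style consumer: range drains buffered updates
		// and returns once the channel is closed. If the channel is only
		// removed from the sender's map but never closed, this goroutine
		// blocks here forever.
		for resp := range s.ch {
			fmt.Println("flushed buffered update", resp.seq)
		}
		close(done)
	}()

	s.ch <- &mapResponse{seq: 1} // buffered before the prune
	s.closeOnce()                // cleanup closes the pruned channel
	s.closeOnce()                // a second racing close is a no-op
	<-done                       // consumer drained the buffer and exited
}

Before this change, stale-send cleanup only dropped the channel from its map, so the consumer side stayed blocked on receive indefinitely, which is exactly the lingering session described above.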
Author: DM
Date: 2026-03-07 22:09:33 +03:00
Committed by: Kristoffer Dalby
Parent: b81d6c734d
Commit: 3daf45e88a
2 changed files with 214 additions and 8 deletions

@@ -556,13 +556,32 @@ func (mc *multiChannelNodeConn) close() {
 	defer mc.mutex.Unlock()
 
 	for _, conn := range mc.connections {
-		// Mark as closed before closing the channel to prevent
-		// send on closed channel panics from concurrent workers
-		conn.closed.Store(true)
-		close(conn.c)
+		mc.closeConnection(conn)
 	}
 }
 
+// closeConnection closes the connection channel at most once, even if multiple
+// cleanup paths race to tear the same session down.
+func (mc *multiChannelNodeConn) closeConnection(conn *connectionEntry) {
+	if conn.closed.CompareAndSwap(false, true) {
+		close(conn.c)
+	}
+}
+
+// removeConnectionAtIndexLocked removes the connection at index i and returns it.
+// If closeChannel is true, it also closes that session's map-response channel.
+// The caller must hold mc.mutex.
+func (mc *multiChannelNodeConn) removeConnectionAtIndexLocked(i int, closeChannel bool) *connectionEntry {
+	conn := mc.connections[i]
+	mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
+	if closeChannel {
+		mc.closeConnection(conn)
+	}
+	return conn
+}
+
 // addConnection adds a new connection.
 func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
 	mutexWaitStart := time.Now()
@@ -590,8 +609,7 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
 	for i, entry := range mc.connections {
 		if entry.c == c {
-			// Remove this connection
-			mc.connections = append(mc.connections[:i], mc.connections[i+1:]...)
+			mc.removeConnectionAtIndexLocked(i, false)
 			mc.log.Debug().Caller().Str(zf.Chan, fmt.Sprintf("%p", c)).
 				Int("remaining_connections", len(mc.connections)).
 				Msg("successfully removed connection")
@@ -673,10 +691,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
 	// Remove failed connections (in reverse order to maintain indices)
 	for i := len(failedConnections) - 1; i >= 0; i-- {
 		idx := failedConnections[i]
+		entry := mc.removeConnectionAtIndexLocked(idx, true)
 		mc.log.Debug().Caller().
-			Str(zf.ConnID, mc.connections[idx].id).
-			Msg("send: removing failed connection")
-		mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...)
+			Str(zf.ConnID, entry.id).
+			Msg("send: removed failed connection")
 	}
 
 	mc.updateCount.Add(1)

hscontrol/poll_test.go (new file, +188)

@@ -0,0 +1,188 @@
package hscontrol

import (
	"context"
	"net/http"
	"sync"
	"testing"
	"time"

	"github.com/juanfont/headscale/hscontrol/mapper"
	"github.com/juanfont/headscale/hscontrol/state"
	"github.com/juanfont/headscale/hscontrol/types/change"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"tailscale.com/tailcfg"
)

type delayedSuccessResponseWriter struct {
	header                 http.Header
	firstWriteDelay        time.Duration
	firstWriteStarted      chan struct{}
	firstWriteStartedOnce  sync.Once
	firstWriteFinished     chan struct{}
	firstWriteFinishedOnce sync.Once

	mu         sync.Mutex
	writeCount int
}

func newDelayedSuccessResponseWriter(firstWriteDelay time.Duration) *delayedSuccessResponseWriter {
	return &delayedSuccessResponseWriter{
		header:             make(http.Header),
		firstWriteDelay:    firstWriteDelay,
		firstWriteStarted:  make(chan struct{}),
		firstWriteFinished: make(chan struct{}),
	}
}

func (w *delayedSuccessResponseWriter) Header() http.Header {
	return w.header
}

func (w *delayedSuccessResponseWriter) WriteHeader(int) {}

func (w *delayedSuccessResponseWriter) Write(data []byte) (int, error) {
	w.mu.Lock()
	w.writeCount++
	writeCount := w.writeCount
	w.mu.Unlock()

	if writeCount == 1 {
		// Only the first write is delayed. This simulates a transiently wedged
		// map response: long enough to make the batcher time out future sends,
		// but short enough that the old session can still recover if we leave
		// it alive.
		w.firstWriteStartedOnce.Do(func() {
			close(w.firstWriteStarted)
		})
		time.Sleep(w.firstWriteDelay)
		w.firstWriteFinishedOnce.Do(func() {
			close(w.firstWriteFinished)
		})
	}

	return len(data), nil
}

func (w *delayedSuccessResponseWriter) Flush() {}

func (w *delayedSuccessResponseWriter) FirstWriteStarted() <-chan struct{} {
	return w.firstWriteStarted
}

func (w *delayedSuccessResponseWriter) FirstWriteFinished() <-chan struct{} {
	return w.firstWriteFinished
}

func (w *delayedSuccessResponseWriter) WriteCount() int {
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.writeCount
}

// Reproducer outline:
// 1. Start a real long-poll session for one node.
// 2. Make the first map write block briefly, so the session stops draining m.ch.
// 3. While that write is blocked, queue enough updates to fill the buffered
// session channel and make the next batcher send hit the stale-send timeout.
// 4. Let the blocked write recover. The stale session should still flush the
// update that was already buffered before its channel was pruned.
// 5. After that buffered update is drained, the stale session must exit instead
// of lingering as an orphaned serveLongPoll goroutine.
func TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession(t *testing.T) {
	t.Parallel()

	app := createTestApp(t)
	user := app.state.CreateUserForTest("poll-stale-session-user")
	createdNode := app.state.CreateRegisteredNodeForTest(user, "poll-stale-session-node")
	require.NoError(t, app.state.UpdatePolicyManagerUsersForTest())

	app.cfg.Tuning.BatchChangeDelay = 20 * time.Millisecond
	app.cfg.Tuning.NodeMapSessionBufferedChanSize = 1

	app.mapBatcher.Close()
	require.NoError(t, app.state.Close())

	reloadedState, err := state.NewState(app.cfg)
	require.NoError(t, err)
	app.state = reloadedState

	app.mapBatcher = mapper.NewBatcherAndMapper(app.cfg, app.state)
	app.mapBatcher.Start()
	t.Cleanup(func() {
		app.mapBatcher.Close()
		require.NoError(t, app.state.Close())
	})

	nodeView, ok := app.state.GetNodeByID(createdNode.ID)
	require.True(t, ok, "expected node to be present in NodeStore after reload")
	require.True(t, nodeView.Valid(), "expected valid node view after reload")
	node := nodeView.AsStruct()

	ctx, cancel := context.WithCancel(context.Background())
	writer := newDelayedSuccessResponseWriter(250 * time.Millisecond)
	session := app.newMapSession(ctx, tailcfg.MapRequest{
		Stream:  true,
		Version: tailcfg.CapabilityVersion(100),
	}, writer, node)

	serveDone := make(chan struct{})
	go func() {
		session.serveLongPoll()
		close(serveDone)
	}()
	t.Cleanup(func() {
		dummyCh := make(chan *tailcfg.MapResponse, 1)
		_ = app.mapBatcher.AddNode(node.ID, dummyCh, tailcfg.CapabilityVersion(100))
		cancel()
		select {
		case <-serveDone:
		case <-time.After(2 * time.Second):
		}
		_ = app.mapBatcher.RemoveNode(node.ID, dummyCh)
	})

	select {
	case <-writer.FirstWriteStarted():
	case <-time.After(2 * time.Second):
		t.Fatal("expected initial map write to start")
	}

	streamsClosed := make(chan struct{})
	go func() {
		app.clientStreamsOpen.Wait()
		close(streamsClosed)
	}()

	// One update fills the buffered session channel while the first write is
	// blocked. The second update then hits the 50ms stale-send timeout, and
	// the batcher prunes and closes that stale channel.
	app.mapBatcher.AddWork(change.SelfUpdate(node.ID), change.SelfUpdate(node.ID))

	select {
	case <-writer.FirstWriteFinished():
	case <-time.After(2 * time.Second):
		t.Fatal("expected the blocked write to eventually complete")
	}

	assert.Eventually(t, func() bool {
		return writer.WriteCount() >= 2
	}, 2*time.Second, 20*time.Millisecond, "session should flush the update that was already buffered before the stale send")

	assert.Eventually(t, func() bool {
		select {
		case <-streamsClosed:
			return true
		default:
			return false
		}
	}, time.Second, 20*time.Millisecond, "after stale-send cleanup, the stale session should exit")
}
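The new reproducer should run with a standard go test invocation along these lines (package path as in the diff above; -race is optional but worthwhile for a concurrency fix):

	go test -race -run TestTransientlyBlockedWriteDoesNotLeaveLiveStaleSession ./hscontrol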