diff --git a/src/Renci.SshNet/SshCommand.cs b/src/Renci.SshNet/SshCommand.cs index 6bfd57422..f900e843b 100644 --- a/src/Renci.SshNet/SshCommand.cs +++ b/src/Renci.SshNet/SshCommand.cs @@ -300,7 +300,18 @@ public Task ExecuteAsync(CancellationToken cancellationToken = default) if (cancellationToken.CanBeCanceled) { - _tokenRegistration = cancellationToken.Register(static cmd => ((SshCommand)cmd!).CancelAsync(), this); + _tokenRegistration = cancellationToken.Register(static cmd => + { + try + { + ((SshCommand)cmd!).CancelAsync(); + } + catch + { + // Swallow exceptions which would otherwise be unhandled. + } + }, + this); } return _tcs.Task; @@ -437,33 +448,31 @@ public void CancelAsync(bool forceKill = false, int millisecondsTimeout = 500) _cancellationRequested = true; Interlocked.MemoryBarrier(); // ensure fresh read in SetAsyncComplete (possibly unnecessary) - // Try to send the cancellation signal. - if (_channel?.SendSignalRequest(forceKill ? "KILL" : "TERM") is null) - { - // Command has completed (in the meantime since the last check). - return; - } - - // Having sent the "signal" message, we expect to receive "exit-signal" - // and then a close message. But since a server may not implement signals, - // we can't guarantee that, so we wait a short time for that to happen and - // if it doesn't, just complete the task ourselves to unblock waiters. - try { - if (_tcs.Task.Wait(millisecondsTimeout)) + // Try to send the cancellation signal. + if (_channel?.SendSignalRequest(forceKill ? "KILL" : "TERM") is null) { + // Command has completed (in the meantime since the last check). return; } + + // Having sent the "signal" message, we expect to receive "exit-signal" + // and then a close message. But since a server may not implement signals, + // we can't guarantee that, so we wait a short time for that to happen and + // if it doesn't, just complete the task ourselves to unblock waiters. + + _ = _tcs.Task.Wait(millisecondsTimeout); } catch (AggregateException) { - // We expect to be here if the server implements signals. + // We expect to be here from the call to Wait if the server implements signals. // But we don't want to propagate the exception on the task from here. - return; } - - SetAsyncComplete(); + finally + { + SetAsyncComplete(); + } } /// diff --git a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs index 46b5e5ddf..8f355f96b 100644 --- a/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs +++ b/test/Renci.SshNet.IntegrationTests/OldIntegrationTests/SshCommandTest.cs @@ -194,6 +194,25 @@ public async Task Test_ExecuteAsync_Timeout() } } + [TestMethod] + [Timeout(15000)] + public async Task Test_ExecuteAsync_Disconnect() + { + using (var client = new SshClient(SshServerHostName, SshServerPort, User.UserName, User.Password)) + { + client.Connect(); + using var cmd = client.CreateCommand("sleep 10s"); + cmd.CommandTimeout = TimeSpan.FromSeconds(2); + + Task executeTask = cmd.ExecuteAsync(); + + client.Disconnect(); + + // Waiting for timeout is not optimal here, but better than hanging indefinitely. + await Assert.ThrowsExceptionAsync(() => executeTask); + } + } + [TestMethod] public void Test_Execute_InvalidCommand() {