diff --git a/src/NLU.DevOps.Luis.Shared/ILuisTrainClient.cs b/src/NLU.DevOps.Luis.Shared/ILuisTrainClient.cs index b56a5c8..953c5d5 100644 --- a/src/NLU.DevOps.Luis.Shared/ILuisTrainClient.cs +++ b/src/NLU.DevOps.Luis.Shared/ILuisTrainClient.cs @@ -46,7 +46,7 @@ public interface ILuisTrainClient : IDisposable /// LUIS app ID. /// LUIS version ID. /// Cancellation token. - Task> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken); + Task>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken); /// /// Imports the LUIS app version. diff --git a/src/NLU.DevOps.Luis.Shared/LuisNLUTrainClient.cs b/src/NLU.DevOps.Luis.Shared/LuisNLUTrainClient.cs index 3c5fad1..41ed03b 100644 --- a/src/NLU.DevOps.Luis.Shared/LuisNLUTrainClient.cs +++ b/src/NLU.DevOps.Luis.Shared/LuisNLUTrainClient.cs @@ -157,14 +157,6 @@ public void Dispose() this.LuisClient.Dispose(); } - private static bool IsTransientStatusCode(HttpStatusCode statusCode) - { - return statusCode == HttpStatusCode.TooManyRequests - || (statusCode >= HttpStatusCode.InternalServerError - && statusCode != HttpStatusCode.HttpVersionNotSupported - && statusCode != HttpStatusCode.NotImplemented); - } - private LuisApp CreateLuisApp(IEnumerable utterances) { var luisApp = this.CreateLuisAppTemplate(); @@ -216,36 +208,31 @@ private async Task PollTrainingStatusAsync(CancellationToken cancellationToken) { while (true) { - try - { - var trainingStatus = await this.LuisClient.GetTrainingStatusAsync(this.LuisAppId, this.LuisConfiguration.VersionId, cancellationToken).ConfigureAwait(false); - var inProgress = trainingStatus - .Select(modelInfo => modelInfo.Details.Status) - .Any(status => status == "InProgress" || status == "Queued"); + var trainingStatus = await Retry.With(cancellationToken).OnTransientErrorResponseAsync(() => + this.LuisClient.GetTrainingStatusAsync(this.LuisAppId, this.LuisConfiguration.VersionId, cancellationToken)) + .ConfigureAwait(false); - if (!inProgress) - { - if (trainingStatus.Any(modelInfo => modelInfo.Details.Status == "Fail")) - { - var failureReasons = trainingStatus - .Where(modelInfo => modelInfo.Details.Status == "Fail") - .Select(modelInfo => $"- {modelInfo.Details.FailureReason}"); + var inProgress = trainingStatus.Value + .Select(modelInfo => modelInfo.Details.Status) + .Any(status => status == "InProgress" || status == "Queued"); - throw new InvalidOperationException($"Failure occurred while training LUIS model:\n{string.Join('\n', failureReasons)}"); - } + if (!inProgress) + { + if (trainingStatus.Value.Any(modelInfo => modelInfo.Details.Status == "Fail")) + { + var failureReasons = trainingStatus.Value + .Where(modelInfo => modelInfo.Details.Status == "Fail") + .Select(modelInfo => $"- {modelInfo.Details.FailureReason}"); - break; + throw new InvalidOperationException($"Failure occurred while training LUIS model:\n{string.Join('\n', failureReasons)}"); } - Logger.LogTrace($"Training jobs not complete. Polling again."); - await Task.Delay(TrainStatusDelay, cancellationToken).ConfigureAwait(false); - } - catch (ErrorResponseException ex) - when (IsTransientStatusCode(ex.Response.StatusCode)) - { - Logger.LogTrace("Received HTTP 429 result from LUIS. Retrying."); - await Task.Delay(TrainStatusDelay, cancellationToken).ConfigureAwait(false); + break; } + + Logger.LogTrace($"Training jobs not complete. Polling again."); + var delay = Retry.GetRetryAfterDelay(trainingStatus.RetryAfter, TrainStatusDelay); + await Task.Delay(delay, cancellationToken).ConfigureAwait(false); } } } diff --git a/src/NLU.DevOps.Luis.Shared/LuisTrainClient.cs b/src/NLU.DevOps.Luis.Shared/LuisTrainClient.cs index 79077d9..d17b33f 100644 --- a/src/NLU.DevOps.Luis.Shared/LuisTrainClient.cs +++ b/src/NLU.DevOps.Luis.Shared/LuisTrainClient.cs @@ -61,9 +61,10 @@ public Task DeleteVersionAsync(string appId, string versionId, CancellationToken return this.AuthoringClient.Versions.DeleteAsync(Guid.Parse(appId), versionId, cancellationToken); } - public Task> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken) + public async Task>> GetTrainingStatusAsync(string appId, string versionId, CancellationToken cancellationToken) { - return this.AuthoringClient.Train.GetStatusAsync(Guid.Parse(appId), versionId, cancellationToken); + var operationResponse = await this.AuthoringClient.Train.GetStatusWithHttpMessagesAsync(Guid.Parse(appId), versionId, cancellationToken: cancellationToken).ConfigureAwait(false); + return OperationResponse.Create(operationResponse.Body, operationResponse.Response); } public Task ImportVersionAsync(string appId, string versionId, LuisApp luisApp, CancellationToken cancellationToken) diff --git a/src/NLU.DevOps.Luis.Shared/NLU.DevOps.Luis.Shared.projitems b/src/NLU.DevOps.Luis.Shared/NLU.DevOps.Luis.Shared.projitems index f499c28..f025647 100644 --- a/src/NLU.DevOps.Luis.Shared/NLU.DevOps.Luis.Shared.projitems +++ b/src/NLU.DevOps.Luis.Shared/NLU.DevOps.Luis.Shared.projitems @@ -18,5 +18,8 @@ + + + diff --git a/src/NLU.DevOps.Luis.Shared/OperationResponse.Generic.cs b/src/NLU.DevOps.Luis.Shared/OperationResponse.Generic.cs new file mode 100644 index 0000000..a835228 --- /dev/null +++ b/src/NLU.DevOps.Luis.Shared/OperationResponse.Generic.cs @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace NLU.DevOps.Luis +{ + /// + /// Information about the batch test evaluation operation status. + /// + /// Type of response value. + public class OperationResponse + { + internal OperationResponse(T value, string retryAfter) + { + this.Value = value; + this.RetryAfter = retryAfter; + } + + /// + /// Gets the response value. + /// + public T Value { get; } + + /// + /// Gets the HTTP 'Retry-After' header. + /// + public string RetryAfter { get; } + } +} diff --git a/src/NLU.DevOps.Luis.Shared/OperationResponse.cs b/src/NLU.DevOps.Luis.Shared/OperationResponse.cs new file mode 100644 index 0000000..4d90bbe --- /dev/null +++ b/src/NLU.DevOps.Luis.Shared/OperationResponse.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace NLU.DevOps.Luis +{ + using System.Linq; + using System.Net.Http; + + /// + /// Factory methods for . + /// + public static class OperationResponse + { + /// + /// Creates an instance of . + /// + /// Type of response value. + /// Response value. + /// HTTP response. + /// Instance of . + public static OperationResponse Create(T value, HttpResponseMessage response = default) + { + var retryAfter = response?.Headers?.GetValues(Retry.RetryAfterHeader).FirstOrDefault(); + return new OperationResponse(value, retryAfter); + } + } +} diff --git a/src/NLU.DevOps.Luis.Shared/Retry.cs b/src/NLU.DevOps.Luis.Shared/Retry.cs new file mode 100644 index 0000000..55f0cda --- /dev/null +++ b/src/NLU.DevOps.Luis.Shared/Retry.cs @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace NLU.DevOps.Luis +{ + using System; + using System.Globalization; + using System.Linq; + using System.Net; + using System.Text.RegularExpressions; + using System.Threading; + using System.Threading.Tasks; + using Microsoft.Azure.CognitiveServices.Language.LUIS.Authoring.Models; +#if LUIS_V2 + using ErrorException = Microsoft.Azure.CognitiveServices.Language.LUIS.Runtime.Models.APIErrorException; +#endif + + internal static class Retry + { + public const string RetryAfterHeader = "Retry-After"; + + private static readonly Regex RetryAfterSecondsRegex = new Regex(@"^\d+$"); + + private static TimeSpan DefaultTransientDelay { get; } = TimeSpan.FromMilliseconds(100); + + public static TimeSpan GetRetryAfterDelay(string retryAfter, TimeSpan? defaultDelay = default) + { + if (retryAfter == null) + { + return defaultDelay ?? DefaultTransientDelay; + } + + if (RetryAfterSecondsRegex.IsMatch(retryAfter)) + { + return TimeSpan.FromSeconds(int.Parse(retryAfter, CultureInfo.InvariantCulture)); + } + + return DateTimeOffset.Parse(retryAfter, CultureInfo.InvariantCulture) - DateTimeOffset.Now; + } + + public static CancellationTokenHolder With(CancellationToken cancellationToken) + { + return new CancellationTokenHolder(cancellationToken); + } + + private static async Task OnTransientExceptionAsync( + Func> func, + Func statusCodeSelector, + Func retryAfterDelaySelector = default, + int retryCount = int.MaxValue, + CancellationToken cancellationToken = default) + where TException : Exception + { + var count = 0; + while (count++ < retryCount) + { + cancellationToken.ThrowIfCancellationRequested(); + + try + { + return await func().ConfigureAwait(false); + } + catch (TException ex) + when (count < retryCount && IsTransientStatusCode(statusCodeSelector(ex))) + { + var delay = GetRetryAfterDelay(retryAfterDelaySelector?.Invoke(ex)); + await Task.Delay(delay, cancellationToken).ConfigureAwait(false); + } + } + + throw new InvalidOperationException("Exception will be rethrown before reaching this point."); + } + + private static bool IsTransientStatusCode(HttpStatusCode statusCode) + { + return statusCode == HttpStatusCode.TooManyRequests + || (statusCode >= HttpStatusCode.InternalServerError + && statusCode != HttpStatusCode.HttpVersionNotSupported + && statusCode != HttpStatusCode.NotImplemented); + } + + public class CancellationTokenHolder + { + public CancellationTokenHolder(CancellationToken cancellationToken) + { + this.CancellationToken = cancellationToken; + } + + private CancellationToken CancellationToken { get; } + + public Task OnTransientErrorAsync(Func> func) + { + return OnTransientExceptionAsync( + func, + (ErrorException ex) => ex.Response.StatusCode, + (ErrorException ex) => ex.Response.Headers?[RetryAfterHeader]?.FirstOrDefault(), + cancellationToken: this.CancellationToken); + } + + public Task OnTransientErrorResponseAsync(Func> func) + { + return OnTransientExceptionAsync( + func, + (ErrorResponseException ex) => ex.Response.StatusCode, + (ErrorResponseException ex) => ex.Response.Headers?[RetryAfterHeader]?.FirstOrDefault(), + cancellationToken: this.CancellationToken); + } + + public Task OnTransientWebExceptionAsync(Func> func) + { + return OnTransientExceptionAsync( + func, + (WebException ex) => (ex.Response as HttpWebResponse)?.StatusCode ?? default, + (WebException ex) => (ex.Response as HttpWebResponse)?.Headers?[RetryAfterHeader], + cancellationToken: this.CancellationToken); + } + } + } +} \ No newline at end of file diff --git a/src/NLU.DevOps.Luis.Tests.Shared/LuisNLUTrainClientTests.cs b/src/NLU.DevOps.Luis.Tests.Shared/LuisNLUTrainClientTests.cs index 1305749..9e9dbfa 100644 --- a/src/NLU.DevOps.Luis.Tests.Shared/LuisNLUTrainClientTests.cs +++ b/src/NLU.DevOps.Luis.Tests.Shared/LuisNLUTrainClientTests.cs @@ -215,13 +215,13 @@ public static async Task TrainingStatusDelayBetweenPolling() It.Is(appId => appId == builder.AppId), It.IsAny(), It.IsAny())) - .Returns(() => Task.FromResult>(new[] + .Returns(() => Task.FromResult(OperationResponse.Create((IList)new[] { new ModelTrainingInfo { Details = new ModelTrainingDetails { Status = statusArray[count++] } } - })) + }))) .Callback(() => timestamps[count - 1] = DateTimeOffset.Now); using (var luis = builder.Build()) @@ -251,13 +251,13 @@ public static void TrainingFailedThrowsInvalidOperation() It.Is(appId => appId == builder.AppId), It.IsAny(), It.IsAny())) - .Returns(() => Task.FromResult>(new[] + .Returns(() => Task.FromResult(OperationResponse.Create((IList)new[] { new ModelTrainingInfo { Details = new ModelTrainingDetails { Status = "Fail", FailureReason = failureReason } } - })); + }))); using (var luis = builder.Build()) { @@ -377,8 +377,9 @@ private class LuisNLUTrainClientBuilder public LuisNLUTrainClient Build() { this.MockLuisTrainClient.SetReturnsDefault( - Task.FromResult>( - Array.Empty())); + Task.FromResult( + OperationResponse.Create>( + Array.Empty()))); var luisConfiguration = new LuisConfiguration(new ConfigurationBuilder() .AddInMemoryCollection(new Dictionary diff --git a/src/NLU.DevOps.Luis/LuisTestClient.cs b/src/NLU.DevOps.Luis/LuisTestClient.cs index 0920e8a..4f42c43 100644 --- a/src/NLU.DevOps.Luis/LuisTestClient.cs +++ b/src/NLU.DevOps.Luis/LuisTestClient.cs @@ -21,8 +21,6 @@ namespace NLU.DevOps.Luis internal sealed class LuisTestClient : ILuisTestClient { - private static readonly TimeSpan ThrottleQueryDelay = TimeSpan.FromMilliseconds(100); - public LuisTestClient(ILuisConfiguration luisConfiguration) { this.LuisConfiguration = luisConfiguration ?? throw new ArgumentNullException(nameof(luisConfiguration)); @@ -42,27 +40,17 @@ public LuisTestClient(ILuisConfiguration luisConfiguration) private LUISRuntimeClient RuntimeClient { get; } - public async Task QueryAsync(string text, CancellationToken cancellationToken) + public Task QueryAsync(string text, CancellationToken cancellationToken) { - while (true) - { - try - { - return await this.RuntimeClient.Prediction.ResolveAsync( - this.LuisConfiguration.AppId, - text, - staging: this.LuisConfiguration.IsStaging, - log: false, - cancellationToken: cancellationToken) - .ConfigureAwait(false); - } - catch (APIErrorException ex) - when (IsTransientStatusCode(ex.Response.StatusCode)) - { - Logger.LogTrace($"Received HTTP {(int)ex.Response.StatusCode} result from Cognitive Services. Retrying."); - await Task.Delay(ThrottleQueryDelay, cancellationToken).ConfigureAwait(false); - } - } + Func> func = () => + this.RuntimeClient.Prediction.ResolveAsync( + this.LuisConfiguration.AppId, + text, + staging: this.LuisConfiguration.IsStaging, + log: false, + cancellationToken: cancellationToken); + + return Retry.OnTransientErrorAsync(func, cancellationToken); } public Task RecognizeSpeechAsync(string speechFile, CancellationToken cancellationToken) @@ -77,14 +65,6 @@ public void Dispose() this.RuntimeClient.Dispose(); } - private static bool IsTransientStatusCode(HttpStatusCode statusCode) - { - return statusCode == HttpStatusCode.TooManyRequests - || (statusCode >= HttpStatusCode.InternalServerError - && statusCode != HttpStatusCode.NotImplemented - && statusCode != HttpStatusCode.HttpVersionNotSupported); - } - private async Task RecognizeSpeechWithIntentRecognizerAsync(string speechFile) { if (this.LuisConfiguration.IsStaging) @@ -136,32 +116,22 @@ private async Task RecognizeSpeechWithEndpointAsync(string spe request.Accept = "application/json"; request.Headers.Add("Ocp-Apim-Subscription-Key", this.LuisConfiguration.SpeechKey); - JObject responseJson; - while (true) - { - try - { - using (var fileStream = File.OpenRead(speechFile)) - using (var requestStream = await request.GetRequestStreamAsync().ConfigureAwait(false)) + var jsonPayload = await Retry.With(cancellationToken).OnTransientWebExceptionAsync(async () => { - await fileStream.CopyToAsync(requestStream).ConfigureAwait(false); - using (var response = await request.GetResponseAsync().ConfigureAwait(false)) - using (var streamReader = new StreamReader(response.GetResponseStream())) + using (var fileStream = File.OpenRead(speechFile)) + using (var requestStream = await request.GetRequestStreamAsync().ConfigureAwait(false)) { - var responseText = await streamReader.ReadToEndAsync().ConfigureAwait(false); - responseJson = JObject.Parse(responseText); - break; + await fileStream.CopyToAsync(requestStream).ConfigureAwait(false); + using (var response = await request.GetResponseAsync().ConfigureAwait(false)) + using (var streamReader = new StreamReader(response.GetResponseStream())) + { + return await streamReader.ReadToEndAsync().ConfigureAwait(false); + } } - } - } - catch (WebException ex) - when (ex.Response is HttpWebResponse response && IsTransientStatusCode(response.StatusCode)) - { - Logger.LogTrace($"Received HTTP {(int)response.StatusCode} result from Cognitive Services. Retrying."); - await Task.Delay(ThrottleQueryDelay, cancellationToken).ConfigureAwait(false); - } - } + }) + .ConfigureAwait(false); + var responseJson = JObject.Parse(jsonPayload); if (responseJson.Value("RecognitionStatus") != "Success") { throw new InvalidOperationException($"Received error from LUIS speech service: {responseJson}"); diff --git a/src/NLU.DevOps.LuisV3/LuisTestClient.cs b/src/NLU.DevOps.LuisV3/LuisTestClient.cs index 245debe..ebcc753 100644 --- a/src/NLU.DevOps.LuisV3/LuisTestClient.cs +++ b/src/NLU.DevOps.LuisV3/LuisTestClient.cs @@ -17,8 +17,6 @@ namespace NLU.DevOps.Luis internal sealed class LuisTestClient : ILuisTestClient { - private static readonly TimeSpan ThrottleQueryDelay = TimeSpan.FromMilliseconds(100); - public LuisTestClient(ILuisConfiguration luisConfiguration) { this.LuisConfiguration = luisConfiguration ?? throw new ArgumentNullException(nameof(luisConfiguration)); @@ -39,41 +37,25 @@ public LuisTestClient(ILuisConfiguration luisConfiguration) private bool QueryTargetTraced { get; set; } - public async Task QueryAsync(PredictionRequest predictionRequest, CancellationToken cancellationToken) + public Task QueryAsync(PredictionRequest predictionRequest, CancellationToken cancellationToken) { - while (true) - { - try - { - this.TraceQueryTarget(); - if (this.LuisConfiguration.DirectVersionPublish) - { - return await this.RuntimeClient.Prediction.GetVersionPredictionAsync( - Guid.Parse(this.LuisConfiguration.AppId), - this.LuisConfiguration.VersionId, - predictionRequest, - verbose: true, - log: false, - cancellationToken: cancellationToken) - .ConfigureAwait(false); - } - - return await this.RuntimeClient.Prediction.GetSlotPredictionAsync( + this.TraceQueryTarget(); + return Retry.With(cancellationToken).OnTransientErrorAsync(() => + this.LuisConfiguration.DirectVersionPublish + ? this.RuntimeClient.Prediction.GetVersionPredictionAsync( + Guid.Parse(this.LuisConfiguration.AppId), + this.LuisConfiguration.VersionId, + predictionRequest, + verbose: true, + log: false, + cancellationToken: cancellationToken) + : this.RuntimeClient.Prediction.GetSlotPredictionAsync( Guid.Parse(this.LuisConfiguration.AppId), this.LuisConfiguration.SlotName, predictionRequest, verbose: true, log: false, - cancellationToken: cancellationToken) - .ConfigureAwait(false); - } - catch (ErrorException ex) - when (IsTransientStatusCode(ex.Response.StatusCode)) - { - Logger.LogTrace($"Received HTTP {(int)ex.Response.StatusCode} result from Cognitive Services. Retrying."); - await Task.Delay(ThrottleQueryDelay, cancellationToken).ConfigureAwait(false); - } - } + cancellationToken: cancellationToken)); } public async Task RecognizeSpeechAsync(string speechFile, PredictionRequest predictionRequest, CancellationToken cancellationToken) @@ -91,32 +73,22 @@ public async Task RecognizeSpeechAsync(string speechFi request.Accept = "application/json"; request.Headers.Add("Ocp-Apim-Subscription-Key", this.LuisConfiguration.SpeechKey); - JObject responseJson; - while (true) - { - try - { - using (var fileStream = File.OpenRead(speechFile)) - using (var requestStream = await request.GetRequestStreamAsync().ConfigureAwait(false)) + var jsonPayload = await Retry.With(cancellationToken).OnTransientWebExceptionAsync(async () => { - await fileStream.CopyToAsync(requestStream).ConfigureAwait(false); - using (var response = await request.GetResponseAsync().ConfigureAwait(false)) - using (var streamReader = new StreamReader(response.GetResponseStream())) + using (var fileStream = File.OpenRead(speechFile)) + using (var requestStream = await request.GetRequestStreamAsync().ConfigureAwait(false)) { - var responseText = await streamReader.ReadToEndAsync().ConfigureAwait(false); - responseJson = JObject.Parse(responseText); - break; + await fileStream.CopyToAsync(requestStream).ConfigureAwait(false); + using (var response = await request.GetResponseAsync().ConfigureAwait(false)) + using (var streamReader = new StreamReader(response.GetResponseStream())) + { + return await streamReader.ReadToEndAsync().ConfigureAwait(false); + } } - } - } - catch (WebException ex) - when (ex.Response is HttpWebResponse response && IsTransientStatusCode(response.StatusCode)) - { - Logger.LogTrace($"Received HTTP {(int)response.StatusCode} result from Cognitive Services. Retrying."); - await Task.Delay(ThrottleQueryDelay, cancellationToken).ConfigureAwait(false); - } - } + }) + .ConfigureAwait(false); + var responseJson = JObject.Parse(jsonPayload); if (responseJson.Value("RecognitionStatus") != "Success") { throw new InvalidOperationException($"Received error from LUIS speech service: {responseJson}"); @@ -143,14 +115,6 @@ public void Dispose() this.RuntimeClient.Dispose(); } - private static bool IsTransientStatusCode(HttpStatusCode statusCode) - { - return statusCode == HttpStatusCode.TooManyRequests - || (statusCode >= HttpStatusCode.InternalServerError - && statusCode != HttpStatusCode.HttpVersionNotSupported - && statusCode != HttpStatusCode.NotImplemented); - } - private void TraceQueryTarget() { if (!this.QueryTargetTraced)