Source

Sepialabs.Azure / Sepialabs.Azure / ParallelTableQuery.cs

Full commit
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.WindowsAzure.StorageClient;
using System.Data.Services.Client;
using System.Threading.Tasks;
using System.Threading;
using Sepialabs.Utilities;

namespace SepiaLabs.Azure
{
    /// <summary>
    /// A class for performing multiple Azure Table queries in parallel, with client side timeouts.
    /// </summary>
    /// <remarks>
    /// Create the object, call AddQuery() for each query you want to run, then call Execute(). 
    /// This is based on a StackOverflow post: http://stackoverflow.com/questions/4535740/generic-class-for-performing-mass-parallel-queries-feedback
    /// Though it has been heavily modified from that base. 
    /// </remarks>
    /// <summary>
    /// A class for performing multiple Azure Table queries in parallel, with client side timeouts.
    /// </summary>
    /// <remarks>
    /// Create the object, call AddQuery() for each query you want to run, then call Execute()
    /// (or BeginAsync() followed by EndAsync()). The instance may be executed again once EndAsync() has returned.
    /// This is based on a StackOverflow post: http://stackoverflow.com/questions/4535740/generic-class-for-performing-mass-parallel-queries-feedback
    /// Though it has been heavily modified from that base.
    /// Segment completions, subquery timeouts and delayed retries are handled on thread-pool threads, so the state
    /// they share with the calling thread is accessed through volatile fields and Interlocked operations.
    /// </remarks>
    public class ParallelTableQuery<T>
        : IDisposable
        where T : new()
    {
        /// <summary>
        /// Default timeout for the entire set of queries. Equal to 90 seconds.
        /// </summary>
        public static readonly TimeSpan DEFAULT_OVERALL_TIMEOUT = TimeSpan.FromSeconds(90);

        /// <summary>
        /// Default timeout for each subquery attempt. Equal to 5 seconds.
        /// </summary>
        public static readonly TimeSpan DEFAULT_QUERY_TIMEOUT = TimeSpan.FromSeconds(5);

        /// <summary>
        /// Signalled when the last subquery completes. Non-null only between BeginAsync() and EndAsync()/Dispose().
        /// </summary>
        private volatile ManualResetEvent waitEvent;

        /// <summary>
        /// Number of queries that are still processing. Decremented with Interlocked once execution has begun.
        /// </summary>
        private int numQueriesRemaining;

        /// <summary>
        /// Private list of the queries we're executing
        /// </summary>
        private List<QueryState> Queries { get; set; }

        /// <summary>
        /// Indicates whether the overall query has ended. When set to true, subqueries will not continue following
        /// continuation tokens or retrying. Written by the caller's thread, read by thread-pool callbacks.
        /// </summary>
        private volatile bool overallQueryEnded;

        /// <summary>
        /// Timeout for each subquery attempt
        /// </summary>
        private TimeSpan queryTimeout;

        /// <summary>
        /// Timeout for the overall query
        /// </summary>
        private TimeSpan overallTimeout;

        /// <summary>
        /// The (UTC) time when the overall query should end
        /// </summary>
        private DateTime overallQueryEnd;

        /// <summary>
        /// Function to determine if an error should be retried, and how long to wait before retrying
        /// </summary>
        private ShouldRetry retriableDecider;

        private ILogger log;

        /// <summary>
        /// Initialize a query, optionally with the specified timeouts
        /// </summary>
        /// <param name="logger">Logger for diagnostics; a no-op logger is used when null.</param>
        /// <param name="overallTimeout">The amount of time allowed to run the entire set of queries. If this is exceeded any remaining queries are abandoned.</param>
        /// <param name="subQueryTimeout">The amount of time allowed for a single subquery attempt before it will be abandoned and retried. Note that Azure queries are timed out on the server side after 30 seconds.</param>
        /// <param name="retryDecider">
        /// Decides whether a failed or timed-out subquery attempt is retried and how long to wait first.
        /// It receives the attempt number (starting at 1) and the exception, which is null when the attempt hit the subquery timeout.
        /// Defaults to up to 3 retries with a 5 second delay.
        /// </param>
        public ParallelTableQuery(ILogger logger = null, TimeSpan? overallTimeout = null, TimeSpan? subQueryTimeout = null, ShouldRetry retryDecider = null)
        {
            this.log = logger ?? NullLogger.Instance;
            this.Queries = new List<QueryState>();
            this.overallTimeout = overallTimeout ?? DEFAULT_OVERALL_TIMEOUT;
            this.queryTimeout = subQueryTimeout ?? DEFAULT_QUERY_TIMEOUT;
            this.retriableDecider = retryDecider ?? DefaultRetriableDecider;
        }

        /// <summary>
        /// Add a new query to the list to be executed.
        /// </summary>
        /// <exception cref="ArgumentNullException">query is null.</exception>
        /// <exception cref="InvalidOperationException">Execution is currently in progress.</exception>
        public void AddQuery(CloudTableQuery<T> query)
        {
            if (query == null)
            {
                throw new ArgumentNullException("query");
            }

            if (this.waitEvent != null)
            {
                throw new InvalidOperationException("Cannot add more queries after execution has already begun");
            }

            this.Queries.Add(new QueryState(query));
        }

        /// <summary>
        /// Execute all the queries in parallel, return the results.
        /// Any exceptions will be thrown as an AggregateException
        /// </summary>
        public IEnumerable<T> Execute()
        {
            AggregateException ae;
            IEnumerable<T> results = this.Execute(out ae);

            if (ae != null)
            {
                throw ae;
            }

            return results;
        }

        /// <summary>
        /// Execute all the queries in parallel, return the results. Any exceptions will be bundled into the out parameter,
        /// allowing you to use the partial results if you wish.
        /// </summary>
        /// <param name="ae">A bundle of all the exceptions from any of the subqueries, or null if there were none. Any queries that timed out will have a TimeoutException.</param>
        /// <returns>Results of the queries, materialized into a list</returns>
        public IEnumerable<T> Execute(out AggregateException ae)
        {
            this.BeginAsync();
            return this.EndAsync(out ae);
        }

        /// <summary>
        /// Begin executing all of the subqueries.
        /// </summary>
        /// <exception cref="InvalidOperationException">Execution is already in progress (EndAsync() has not been called).</exception>
        public void BeginAsync()
        {
            if (this.waitEvent != null)
            {
                throw new InvalidOperationException("Execution has already begun; call EndAsync() before calling BeginAsync() again");
            }

            // Clear anything left over from a previous execution so results and errors are not duplicated
            foreach (QueryState queryState in this.Queries)
            {
                queryState.Reset();
            }

            this.numQueriesRemaining = this.Queries.Count;
            this.overallQueryEnded = false;
            this.overallQueryEnd = DateTime.UtcNow.Add(this.overallTimeout);

            // With no queries there is nothing to wait for, so start signalled and EndAsync() returns immediately
            this.waitEvent = new ManualResetEvent(this.Queries.Count == 0);

            foreach (QueryState queryState in this.Queries)
            {
                queryState.AttemptNumber = 1;
                try
                {
                    RunQuery(queryState);
                }
                catch (Exception ex)
                {
                    // BeginExecuteSegmented() can throw synchronously; report it through EndAsync() like any other failure
                    log.Error("Caught exception starting query: {0}", ex);
                    queryState.Error = ex;
                    QueryCompleted(queryState);
                }
            }
        }

        /// <summary>
        /// Wait for the query list to finish executing, or for the overall timeout to elapse.
        /// </summary>
        /// <param name="exceptions">A bundle of all the exceptions from any of the subqueries, or null if there were none. Any queries that timed out will have a TimeoutException.</param>
        /// <returns>Results of the queries, materialized into a list</returns>
        /// <exception cref="InvalidOperationException">BeginAsync() has not been called.</exception>
        public IEnumerable<T> EndAsync(out AggregateException exceptions)
        {
            ManualResetEvent evt = this.waitEvent;
            if (evt == null)
            {
                throw new InvalidOperationException("BeginAsync() must be called before EndAsync()");
            }

            // WaitOne() rejects negative timeouts, which occur when EndAsync() is called after the deadline has passed
            TimeSpan timeout = this.overallQueryEnd - DateTime.UtcNow;
            if (timeout < TimeSpan.Zero)
            {
                timeout = TimeSpan.Zero;
            }

            bool allQueriesCompleted = evt.WaitOne(timeout);

            // Snapshot the count before raising the end flag: any query abandoned because of the flag
            // was still incomplete at this point, so it is included in the timeout count
            int numTimeouts = Thread.VolatileRead(ref this.numQueriesRemaining);
            this.overallQueryEnded = true; //signal any remaining, timed out queries that we're done with them

            // Clear the field before disposing; QueryCompleted() also tolerates an already-disposed event
            this.waitEvent = null;
            evt.Dispose();

            // Materialize now: abandoned queries may still append results on other threads after we return
            List<T> results = this.Queries
                .Select(qs => qs.Results)
                .Where(r => r != null)
                .SelectMany(r => r)
                .ToList();

            List<Exception> errors = this.Queries
                .Select(qs => qs.Error)
                .Where(e => e != null)
                .ToList();

            if (!allQueriesCompleted)
            {
                if (numTimeouts > 0)
                {
                    errors.Add(new TimeoutException(numTimeouts.ToString() + " of " + this.Queries.Count + " queries timed out"));
                }

                LogIncompleteQueries();
            }

            exceptions = errors.Count > 0 ? new AggregateException(errors) : null;
            return results;
        }

        /// <summary>
        /// Write a warning describing every subquery that had not completed when the overall query ended.
        /// </summary>
        private void LogIncompleteQueries()
        {
            StringBuilder errorSb = new StringBuilder("The following queries were not completed during async query: \r\n");
            foreach (QueryState notComplete in this.Queries.Where(q => !q.IsCompleted))
            {
                // Read shared fields once; callbacks may still be updating them
                T[] partialResults = notComplete.Results;
                Exception error = notComplete.Error;

                errorSb.AppendFormat("IsCompleted: {0}, SegmentsCompleted: {1}, AttemptNumber: {2}, ResultCount: {3}, HasContinuationToken: {4}, HasError: {5}, Query: {6}\r\n",
                        notComplete.IsCompleted,
                        notComplete.NumSegmentsCompleted,
                        notComplete.AttemptNumber,
                        (partialResults == null ? 0 : partialResults.Length),
                        (notComplete.ContinuationToken == null ? "no" : "yes"),
                        (error == null ? "no" : "yes (" + (error.Message) + ")"),
                        notComplete.Query.ToString()
                    );
            }
            errorSb.AppendLine(Environment.StackTrace);
            log.Warn(errorSb.ToString());
        }

        /// <summary>
        /// Callback function used for handling results of a query segment, or the time out of the segment
        /// </summary>
        /// <param name="opState">The IAsyncResult of the segment request, whose AsyncState is the QueryState</param>
        /// <param name="timedOut">Indicates whether the segment timed out before being completed</param>
        private void QuerySegmentCompleted(object opState, bool timedOut)
        {
            QueryState state = null;

            try
            {
                IAsyncResult asyncResult = (IAsyncResult)opState;
                state = (QueryState)asyncResult.AsyncState;

                // Unregister immediately to clean up resources. The registration may not have been stored yet if
                // the request completed before RegisterWaitForSingleObject() returned, so tolerate null.
                // Pass null: Unregister() signals whatever wait handle it is given.
                RegisteredWaitHandle registration = state.TimeoutWaitHandle;
                if (registration != null)
                {
                    registration.Unregister(null);
                }

                if (timedOut)
                {
                    RetryQuery(state);
                    return;
                }

                try
                {
                    ResultSegment<T> response = state.Query.EndExecuteSegmented(asyncResult);
                    state.NumSegmentsCompleted++;

                    // This segment succeeded, so any error from an earlier attempt has been recovered from
                    state.Error = null;

                    //Save results into the QueryState object
                    state.AddResults(response.Results);

                    //Check for a continuation, and execute if necessary
                    if (response.HasMoreResults)
                    {
                        state.ContinuationToken = response.ContinuationToken;
                        state.AttemptNumber = 1;
                        RunQuery(state);
                    }
                    else
                    {
                        //No more results, we're done
                        QueryCompleted(state);
                    }
                }
                catch (Exception ex)
                {
                    RetryQuery(state, ex);
                }
            }
            catch (Exception outerEx)
            {
                log.Error("Caught unexpected exception completing query segment: {0}", outerEx);

                if (state != null)
                {
                    state.Error = outerEx;
                    QueryCompleted(state);
                }
            }
        }

        /// <summary>
        /// Execute the query using whatever continuation token may be in the QueryState.
        /// If the overall query has already ended, the subquery is marked completed instead.
        /// </summary>
        private void RunQuery(QueryState queryState)
        {
            if (this.overallQueryEnded)
            {
                //If the overall query has finished, then there's no point doing this
                QueryCompleted(queryState);
                return;
            }

            IAsyncResult asyncResult;

            if (queryState.ContinuationToken != null)
            {
                asyncResult = queryState.Query.BeginExecuteSegmented(queryState.ContinuationToken, NoopAsyncCallback, queryState);
            }
            else
            {
                asyncResult = queryState.Query.BeginExecuteSegmented(NoopAsyncCallback, queryState);
            }

            //The BeginExecuteSegmented() method will not respect any client side timeouts that we set
            //So we have to do timeouts ourselves
            //RegisterWaitForSingleObject() will execute the callback when the query completes, or when the specified timeout has elapsed
            queryState.TimeoutWaitHandle = ThreadPool.RegisterWaitForSingleObject(asyncResult.AsyncWaitHandle, QuerySegmentCompleted, asyncResult, queryTimeout, true);
        }

        /// <summary>
        /// Retry an individual subquery, if the retry policy allows it and the overall query has not ended.
        /// Otherwise mark it completed with the failure recorded.
        /// </summary>
        /// <param name="state">The subquery whose attempt failed</param>
        /// <param name="ex">The exception from the attempt, or null if the attempt timed out</param>
        private void RetryQuery(QueryState state, Exception ex = null)
        {
            // The failure to report if we give up; a null exception means the subquery timeout elapsed
            Exception failure = ex ?? new TimeoutException("Query attempt " + state.AttemptNumber + " timed out after " + this.queryTimeout);

            TimeSpan retryDelay = TimeSpan.Zero;
            if (this.overallQueryEnded
                || !retriableDecider(state.AttemptNumber, ex, out retryDelay)
                || DateTime.UtcNow.Add(retryDelay) >= this.overallQueryEnd)
            {
                // Record why we gave up; otherwise this query's missing results would go unreported
                state.Error = failure;
                QueryCompleted(state);
                return;
            }

            state.AttemptNumber++;
            state.Error = failure; // kept for diagnostics while the retry is pending; cleared when a segment succeeds

            if (retryDelay > TimeSpan.Zero)
            {
                ScheduleRetry(state, retryDelay);
            }
            else
            {
                RunQuery(state);
            }
        }

        /// <summary>
        /// Run the subquery again after the given delay, using a one-shot timer.
        /// </summary>
        private void ScheduleRetry(QueryState state, TimeSpan delay)
        {
            // Create the timer stopped and store it on the state before starting it, so it stays reachable
            // (an unreferenced Timer can be collected before firing) and no store can race with its callback.
            System.Threading.Timer timer = new System.Threading.Timer(OnRetryTimerElapsed, state, Timeout.Infinite, Timeout.Infinite);

            // Any previous retry timer for this state has already fired, since it started the attempt being retried
            System.Threading.Timer previous = Interlocked.Exchange(ref state.RetryTimer, timer);
            if (previous != null)
            {
                previous.Dispose();
            }

            timer.Change(delay, TimeSpan.FromMilliseconds(Timeout.Infinite));
        }

        /// <summary>
        /// Timer callback for a delayed retry.
        /// </summary>
        private void OnRetryTimerElapsed(object timerState)
        {
            QueryState state = (QueryState)timerState;
            try
            {
                RunQuery(state);
            }
            catch (Exception ex)
            {
                // An exception escaping a thread-pool callback would terminate the process; report it through the query instead
                log.Error("Caught unexpected exception retrying query: {0}", ex);
                state.Error = ex;
                QueryCompleted(state);
            }
        }

        /// <summary>
        /// Mark a subquery as completed, update the completed query counter, and possibly notify any waiters.
        /// Safe to call more than once per subquery and after the wait event has been disposed.
        /// </summary>
        private void QueryCompleted(QueryState qs)
        {
            if (!qs.TryMarkCompleted())
            {
                // Already counted; decrementing again would corrupt the remaining-query count
                return;
            }

            int newVal = Interlocked.Decrement(ref numQueriesRemaining);
            if (newVal <= 0)
            {
                ManualResetEvent evt = this.waitEvent;
                if (evt != null)
                {
                    try
                    {
                        evt.Set();
                    }
                    catch (ObjectDisposedException)
                    {
                        // EndAsync()/Dispose() stopped waiting concurrently; nobody needs the signal
                    }
                }
            }
        }

        /// <summary>
        /// Stop any in-flight subqueries from continuing or retrying, and release the wait event and retry timers.
        /// </summary>
        public void Dispose()
        {
            this.overallQueryEnded = true;

            ManualResetEvent evt = this.waitEvent;
            this.waitEvent = null;
            if (evt != null)
            {
                evt.Dispose();
            }

            foreach (QueryState qs in this.Queries)
            {
                qs.CancelPendingRetry();
            }
        }

        /// <summary>
        /// A no-op callback for the BeginExecuteSegmented() function, because we want to use the waithandle instead
        /// </summary>
        /// <param name="ar">Ignored</param>
        public void NoopAsyncCallback(IAsyncResult ar)
        {
            //I'm not doin' nuthin.
        }

        /// <summary>
        /// Basic retry implementation: retry up to 3 times (attempts 1 through 3), waiting 5 seconds before each retry.
        /// </summary>
        private static bool DefaultRetriableDecider(int attempt, Exception ex, out TimeSpan delay)
        {
            if (attempt <= 3)
            {
                delay = TimeSpan.FromSeconds(5);
                return true;
            }

            delay = TimeSpan.Zero;
            return false;
        }

        /// <summary>
        /// Tracks the state of each of the subqueries we're running.
        /// </summary>
        private class QueryState
        {
            public QueryState(CloudTableQuery<T> query)
            {
                this.Query = query;
                this.Reset();
            }

            /// <summary>
            /// Return to the not-yet-executed state, cancelling any pending retry.
            /// </summary>
            public void Reset()
            {
                this.CancelPendingRetry();
                this.AttemptNumber = 0;
                this.ContinuationToken = null;
                this.Results = null;
                this.Error = null;
                this.NumSegmentsCompleted = 0;
                this.TimeoutWaitHandle = null;
                Interlocked.Exchange(ref this.completedFlag, 0);
            }

            /// <summary>
            /// Append a segment's results. Results is replaced rather than mutated so concurrent readers see a consistent array.
            /// </summary>
            public void AddResults(IEnumerable<T> newResult)
            {
                if (newResult == null)
                {
                    return;
                }

                this.Results = this.Results == null
                    ? newResult.ToArray()
                    : this.Results.Concat(newResult).ToArray();
            }

            /// <summary>
            /// Atomically mark this subquery completed. Returns false if it was already completed.
            /// </summary>
            public bool TryMarkCompleted()
            {
                return Interlocked.CompareExchange(ref this.completedFlag, 1, 0) == 0;
            }

            /// <summary>
            /// Dispose the pending retry timer, if any, so it never fires.
            /// </summary>
            public void CancelPendingRetry()
            {
                System.Threading.Timer timer = Interlocked.Exchange(ref this.RetryTimer, null);
                if (timer != null)
                {
                    timer.Dispose();
                }
            }

            /// <summary>
            /// Whether this subquery has finished (successfully, with an error, or abandoned).
            /// </summary>
            public bool IsCompleted
            {
                get { return Thread.VolatileRead(ref this.completedFlag) != 0; }
            }

            public T[] Results;
            public CloudTableQuery<T> Query;
            public int AttemptNumber;
            public ResultContinuation ContinuationToken;
            public Exception Error;
            public int NumSegmentsCompleted;
            public RegisteredWaitHandle TimeoutWaitHandle;
            public System.Threading.Timer RetryTimer;

            // 0 = running, 1 = completed; changed only through Interlocked
            private int completedFlag;
        }
    }
}