// --------------------------------------------------------------------------------------------------------------------
//
//   Copyright (c) by respective owners including Yahoo!, Microsoft, and
//   individual contributors. All rights reserved.  Released under a BSD
//   license as described in the file LICENSE.
//
// --------------------------------------------------------------------------------------------------------------------

using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;

namespace VW
{
    /// <summary>
    /// VW wrapper supporting multi-core learning by utilizing thread-based allreduce.
    /// </summary>
    public class VowpalWabbitThreadedLearning : IDisposable
    {
        /// <summary>
        /// Random generator used by uniform random example distributor.
        /// </summary>
        /// <remarks>Initialized with static seed to enable reproducibility.</remarks>
        private readonly Random random = new Random(42);

        /// <summary>
        /// Guards <see cref="random"/>: <see cref="Random"/> instances are not thread-safe and
        /// <see cref="Post(Action{VowpalWabbit})"/> may be invoked from multiple threads.
        /// </summary>
        private readonly object randomLock = new object();

        /// <summary>
        /// Configurable example distribution function choosing the vw instance for the next example.
        /// Maps the running example count to an index into <see cref="observers"/>.
        /// </summary>
        private readonly Func<uint, int> exampleDistributor;

        /// <summary>
        /// Native vw instances setup for thread-based allreduce.
        /// </summary>
        private VowpalWabbit[] vws;

        /// <summary>
        /// Worker threads with a nice message queue in front that will start blocking once it's too full.
        /// </summary>
        private readonly ActionBlock<Action<VowpalWabbit>>[] actionBlocks;

        /// <summary>
        /// The <see cref="ActionBlock{T}"/>s only offer non-blocking methods. Getting observers and
        /// calling OnNext() enables blocking once the queue is full.
        /// </summary>
        private readonly IObserver<Action<VowpalWabbit>>[] observers;

        /// <summary>
        /// Invoked right after the root node performed AllReduce with the other instances.
        /// </summary>
        private readonly ConcurrentList<Action<VowpalWabbit>> syncActions;

        /// <summary>
        /// Tasks enabling clients to wait for completion after all action blocks have finished (incl. cleanup).
        /// </summary>
        private Task[] completionTasks;

        /// <summary>
        /// Number of examples seen so far. Used by round robin example distributor.
        /// </summary>
        private int exampleCount;

        /// <summary>
        /// Initializes a new instance of the <see cref="VowpalWabbitThreadedLearning"/> class.
        /// </summary>
        /// <param name="settings">Common settings used for vw instances.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="settings"/> or its ParallelOptions is null.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// ParallelOptions.MaxDegreeOfParallelism is not positive.
        /// </exception>
        /// <exception cref="ArgumentException">The configured example distribution is not supported.</exception>
        public VowpalWabbitThreadedLearning(VowpalWabbitSettings settings)
        {
            if (settings == null)
            {
                throw new ArgumentNullException("settings");
            }

            if (settings.ParallelOptions == null)
            {
                throw new ArgumentNullException("settings.ParallelOptions", "settings.ParallelOptions must be set");
            }

            // ParallelOptions defaults to -1 ("unlimited"), which cannot be used to size the instance arrays.
            if (settings.ParallelOptions.MaxDegreeOfParallelism <= 0)
            {
                throw new ArgumentOutOfRangeException(
                    "settings",
                    "settings.ParallelOptions.MaxDegreeOfParallelism must be a positive number");
            }

            Contract.EndContractBlock();

            this.Settings = settings;

            // NOTE: CancellationToken is a value type and can never be null, hence no default is assigned here.

            switch (this.Settings.ExampleDistribution)
            {
                case VowpalWabbitExampleDistribution.UniformRandom:
                    this.exampleDistributor = _ =>
                    {
                        // serialize access, System.Random is not thread-safe
                        lock (this.randomLock)
                        {
                            return this.random.Next(this.observers.Length);
                        }
                    };
                    break;

                case VowpalWabbitExampleDistribution.RoundRobin:
                    this.exampleDistributor = localExampleCount => (int)(localExampleCount % this.observers.Length);
                    break;

                default:
                    // fail early instead of hitting a null distributor on the first Post()
                    throw new ArgumentException(
                        "Unsupported example distribution: " + this.Settings.ExampleDistribution,
                        "settings");
            }

            this.exampleCount = 0;
            this.syncActions = new ConcurrentList<Action<VowpalWabbit>>();

            var instanceCount = settings.ParallelOptions.MaxDegreeOfParallelism;
            this.vws = new VowpalWabbit[instanceCount];
            this.actionBlocks = new ActionBlock<Action<VowpalWabbit>>[instanceCount];

            // setup AllReduce chain: node 0 is the root, all other nodes are attached to it
            for (int i = 0; i < instanceCount; i++)
            {
                var nodeSettings = (VowpalWabbitSettings)settings.Clone();
                if (i > 0)
                {
                    nodeSettings.Root = this.vws[0];
                }

                nodeSettings.Node = (uint)i;

                var vw = this.vws[i] = new VowpalWabbit(nodeSettings);
                this.actionBlocks[i] = CreateActionBlock(vw, settings);
            }

            // get observers to allow for blocking calls
            this.observers = this.actionBlocks.Select(ab => ab.AsObserver()).ToArray();

            this.completionTasks = new Task[instanceCount];

            // root: perform final AllReduce and execute the outstanding synchronization actions
            {
                var vw = this.vws[0];
                this.completionTasks[0] = this.actionBlocks[0].Completion
                    .ContinueWith(_ => this.EndOfPassAndSync(vw));
            }

            for (int i = 1; i < this.vws.Length; i++)
            {
                // perform final AllReduce
                var vw = this.vws[i];
                this.completionTasks[i] = this.actionBlocks[i].Completion
                    .ContinueWith(_ => vw.EndOfPass(), this.Settings.ParallelOptions.CancellationToken);
            }
        }

        /// <summary>
        /// VowpalWabbit instances participating in AllReduce.
        /// </summary>
        public VowpalWabbit[] VowpalWabbits
        {
            get { return this.vws; }
        }

        /// <summary>
        /// The settings shared across all instances.
        /// </summary>
        public VowpalWabbitSettings Settings { get; private set; }

        /// <summary>
        /// Synchronized performance statistics.
        /// </summary>
        /// <remarks>
        /// The task is only completed after the next synchronization of all instances, which is triggered
        /// every <see cref="VowpalWabbitSettings.ExampleCountPerRun"/> examples or on completion.
        /// </remarks>
        public Task<VowpalWabbitPerformanceStatistics> PerformanceStatistics
        {
            get
            {
                var completionSource = new TaskCompletionSource<VowpalWabbitPerformanceStatistics>();

                this.syncActions.Add(vw =>
                {
                    try
                    {
                        completionSource.SetResult(vw.PerformanceStatistics);
                    }
                    catch (Exception ex)
                    {
                        // surface the failure to the awaiting caller instead of leaving the task pending
                        completionSource.SetException(ex);
                    }
                });

                return completionSource.Task;
            }
        }

        /// <summary>
        /// Creates a new instance of <see cref="VowpalWabbitAsync{TExample}"/> to feed examples of type
        /// <typeparamref name="TExample"/>.
        /// </summary>
        /// <typeparam name="TExample">The user example type.</typeparam>
        /// <returns>A new instance of <see cref="VowpalWabbitAsync{TExample}"/>.</returns>
        public VowpalWabbitAsync<TExample> Create<TExample>()
        {
            return new VowpalWabbitAsync<TExample>(this);
        }

        /// <summary>
        /// Creates a new instance of <see cref="VowpalWabbitAsync{TExample,TActionDependentFeature}"/> to feed
        /// multi-line examples of type <typeparamref name="TExample"/> and
        /// <typeparamref name="TActionDependentFeature"/>.
        /// </summary>
        /// <typeparam name="TExample">The user example type.</typeparam>
        /// <typeparam name="TActionDependentFeature">The user action dependent feature type.</typeparam>
        /// <returns>A new instance of <see cref="VowpalWabbitAsync{TExample,TActionDependentFeature}"/>.</returns>
        public VowpalWabbitAsync<TExample, TActionDependentFeature> Create<TExample, TActionDependentFeature>()
        {
            return new VowpalWabbitAsync<TExample, TActionDependentFeature>(this);
        }

        /// <summary>
        /// Enqueues an action to be executed on one of vw instances.
        /// </summary>
        /// <param name="action">The action to be executed (e.g. Learn/Predict/...).</param>
        /// <remarks>
        /// If the number of actions waiting to be executed has reached
        /// <see cref="VowpalWabbitSettings.MaxExampleQueueLengthPerInstance"/> this method blocks.
        /// </remarks>
        public void Post(Action<VowpalWabbit> action)
        {
            Contract.Requires(action != null);

            var exampleCount = this.CheckEndOfPass();

            // dispatch
            this.observers[this.exampleDistributor(exampleCount)].OnNext(action);
        }

        /// <summary>
        /// Enqueues a task to be executed by single VowpalWabbit instance.
        /// </summary>
        /// <remarks>
        /// Which VowpalWabbit instance is chosen, is determined by
        /// <see cref="VowpalWabbitSettings.ExampleDistribution"/>.
        /// </remarks>
        /// <typeparam name="T">The return type of the task.</typeparam>
        /// <param name="func">The task to be executed.</param>
        /// <returns>The awaitable result of the supplied task.</returns>
        internal Task<T> Post<T>(Func<VowpalWabbit, T> func)
        {
            Contract.Requires(func != null);

            var exampleCount = this.CheckEndOfPass();

            var completionSource = new TaskCompletionSource<T>();

            // dispatch to a Vowpal Wabbit instance
            this.observers[this.exampleDistributor(exampleCount)].OnNext(vw =>
            {
                try
                {
                    completionSource.SetResult(func(vw));
                }
                catch (Exception ex)
                {
                    completionSource.SetException(ex);
                }
            });

            return completionSource.Task;
        }

        /// <summary>
        /// Learns from the given example.
        /// </summary>
        /// <param name="line">The example to learn.</param>
        public void Learn(string line)
        {
            Contract.Requires(line != null);

            this.Post(vw => vw.Learn(line));
        }

        /// <summary>
        /// Learns from the given example.
        /// </summary>
        /// <param name="lines">The multi-line example to learn.</param>
        public void Learn(IEnumerable<string> lines)
        {
            Contract.Requires(lines != null);

            this.Post(vw => vw.Learn(lines));
        }

        /// <summary>
        /// Signal that no more examples are sent.
        /// </summary>
        /// <returns>Task completes once the learning and cleanup is done.</returns>
        public Task Complete()
        {
            // make sure no more sync actions are added, which might otherwise never be called
            this.syncActions.CompleteAdding();

            foreach (var actionBlock in this.actionBlocks)
            {
                actionBlock.Complete();
            }

            return Task.WhenAll(this.completionTasks);
        }

        /// <summary>
        /// Saves a model as part of the synchronization.
        /// </summary>
        /// <returns>Task completes once the model is saved.</returns>
        public Task SaveModel()
        {
            return this.EnqueueSyncAction(vw => vw.SaveModel());
        }

        /// <summary>
        /// Saves a model as part of the synchronization.
        /// </summary>
        /// <param name="filename">The file the model is saved to.</param>
        /// <returns>Task completes once the model is saved.</returns>
        public Task SaveModel(string filename)
        {
            Contract.Requires(!string.IsNullOrEmpty(filename));

            return this.EnqueueSyncAction(vw => vw.SaveModel(filename));
        }

        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public void Dispose()
        {
            this.Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Creates the single-threaded, bounded worker queue feeding the given vw instance.
        /// </summary>
        private static ActionBlock<Action<VowpalWabbit>> CreateActionBlock(VowpalWabbit vw, VowpalWabbitSettings settings)
        {
            return new ActionBlock<Action<VowpalWabbit>>(
                action => action(vw),
                new ExecutionDataflowBlockOptions
                {
                    MaxDegreeOfParallelism = 1,
                    TaskScheduler = settings.ParallelOptions.TaskScheduler,
                    CancellationToken = settings.ParallelOptions.CancellationToken,
                    BoundedCapacity = (int)settings.MaxExampleQueueLengthPerInstance
                });
        }

        /// <summary>
        /// Root node synchronization: performs AllReduce and then executes (and removes) all pending
        /// synchronization actions.
        /// </summary>
        private void EndOfPassAndSync(VowpalWabbit vw)
        {
            // perform AllReduce
            vw.EndOfPass();

            // execute synchronization actions
            foreach (var syncAction in this.syncActions.RemoveAll())
            {
                syncAction(vw);
            }
        }

        /// <summary>
        /// Registers a synchronization action and returns a task that completes once it was executed,
        /// or faults if the action throws.
        /// </summary>
        private Task EnqueueSyncAction(Action<VowpalWabbit> syncAction)
        {
            var completionSource = new TaskCompletionSource<bool>();

            this.syncActions.Add(vw =>
            {
                try
                {
                    syncAction(vw);
                    completionSource.SetResult(true);
                }
                catch (Exception ex)
                {
                    // surface the failure to the awaiting caller instead of leaving the task pending
                    completionSource.SetException(ex);
                }
            });

            return completionSource.Task;
        }

        /// <summary>
        /// Every time <see cref="VowpalWabbitSettings.ExampleCountPerRun"/> examples have been enqueued,
        /// an AllReduce-sync operation (EndOfPass) is injected into the queue of every instance.
        /// </summary>
        /// <returns>The number of examples enqueued so far.</returns>
        private uint CheckEndOfPass()
        {
            var exampleCount = (uint)Interlocked.Increment(ref this.exampleCount);

            // since there is no lock on the input queue, it's not guaranteed that exactly
            // that number of examples are processed (but maybe a few more).
            if (exampleCount % this.Settings.ExampleCountPerRun == 0)
            {
                // root: AllReduce followed by the synchronization actions
                this.observers[0].OnNext(this.EndOfPassAndSync);

                for (int i = 1; i < this.observers.Length; i++)
                {
                    // perform AllReduce
                    this.observers[i].OnNext(vw => vw.EndOfPass());
                }
            }

            return exampleCount;
        }

        private void Dispose(bool disposing)
        {
            if (!disposing)
            {
                return;
            }

            if (this.completionTasks != null)
            {
                // mark completion and wait for all action blocks (incl. their final AllReduce) to finish;
                // Complete() already returns the aggregate of all completion tasks.
                this.Complete()
                    .Wait(this.Settings.ParallelOptions.CancellationToken);

                this.completionTasks = null;
            }

            if (this.vws != null)
            {
                foreach (var vw in this.vws)
                {
                    vw.Dispose();
                }

                this.vws = null;
            }
        }

        /// <summary>
        /// Thread-safe list implementation supporting completion.
        /// </summary>
        /// <typeparam name="T">The element type.</typeparam>
        private class ConcurrentList<T>
        {
            private bool completed = false;

            private readonly List<T> items = new List<T>();

            private readonly object lockObject = new object();

            /// <summary>
            /// Adds an object to the end of the list.
            /// </summary>
            /// <param name="item">The object to be added to the list.</param>
            /// <remarks>
            /// Throws an <see cref="InvalidOperationException"/> if <see cref="CompleteAdding"/> was called previously.
            /// </remarks>
            public void Add(T item)
            {
                lock (this.lockObject)
                {
                    if (this.completed)
                    {
                        throw new InvalidOperationException("ConcurrentList has been marked completed.");
                    }

                    this.items.Add(item);
                }
            }

            /// <summary>
            /// Marks this list as complete. Any subsequent calls to <see cref="Add"/> will trigger an
            /// <see cref="InvalidOperationException"/>.
            /// </summary>
            public void CompleteAdding()
            {
                lock (this.lockObject)
                {
                    this.completed = true;
                }
            }

            /// <summary>
            /// Removes all elements from the list.
            /// </summary>
            /// <returns>The elements removed.</returns>
            public T[] RemoveAll()
            {
                lock (this.lockObject)
                {
                    var ret = this.items.ToArray();
                    this.items.Clear();
                    return ret;
                }
            }
        }
    }
}