// --------------------------------------------------------------------------------------------------------------------
//
// Copyright (c) by respective owners including Yahoo!, Microsoft, and
// individual contributors. All rights reserved. Released under a BSD
// license as described in the file LICENSE.
//
// --------------------------------------------------------------------------------------------------------------------
using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
namespace VW
{
    /// <summary>
    /// VW wrapper supporting multi-core learning by utilizing thread-based allreduce.
    /// </summary>
    public class VowpalWabbitThreadedLearning : IDisposable
    {
        /// <summary>
        /// Random generator used by uniform random example distributor.
        /// </summary>
        /// <remarks>
        /// Initialized with static seed to enable reproducibility. <see cref="Random"/> is not
        /// thread-safe, therefore every access is serialized by locking on this instance.
        /// </remarks>
        private readonly Random random = new Random(42);

        /// <summary>
        /// Configurable example distribution function choosing the vw instance for the next example.
        /// Receives the number of examples enqueued so far and returns an index into <see cref="observers"/>.
        /// </summary>
        private readonly Func<uint, int> exampleDistributor;

        /// <summary>
        /// Native vw instances setup for thread-based allreduce.
        /// </summary>
        private VowpalWabbit[] vws;

        /// <summary>
        /// Worker threads with a nice message queue in front that will start blocking once it's too full.
        /// </summary>
        private readonly ActionBlock<Action<VowpalWabbit>>[] actionBlocks;

        /// <summary>
        /// The <see cref="actionBlocks"/> only offer non-blocking methods. Getting observers and calling OnNext() enables
        /// blocking once the queue is full.
        /// </summary>
        private readonly IObserver<Action<VowpalWabbit>>[] observers;

        /// <summary>
        /// Invoked right after the root node performed AllReduce with the other instances.
        /// </summary>
        private readonly ConcurrentList<Action<VowpalWabbit>> syncActions;

        /// <summary>
        /// Tasks enabling clients to wait for completion after all action blocks have finished (incl. cleanup).
        /// </summary>
        private Task[] completionTasks;

        /// <summary>
        /// Number of examples seen so far. Used by round robin example distributor.
        /// </summary>
        private int exampleCount;

        /// <summary>
        /// Initializes a new instance of the <see cref="VowpalWabbitThreadedLearning"/> class.
        /// </summary>
        /// <param name="settings">Common settings used for vw instances.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="settings"/> or its ParallelOptions is null.
        /// </exception>
        /// <exception cref="ArgumentOutOfRangeException">
        /// The configured degree of parallelism is not positive, or the example distribution is not supported.
        /// </exception>
        public VowpalWabbitThreadedLearning(VowpalWabbitSettings settings)
        {
            if (settings == null)
                throw new ArgumentNullException("settings");
            if (settings.ParallelOptions == null)
                throw new ArgumentNullException("settings.ParallelOptions", "settings.ParallelOptions must be set");
            // MaxDegreeOfParallelism is used as the number of vw instances (array length below),
            // so "unlimited" (-1) cannot be honoured here.
            if (settings.ParallelOptions.MaxDegreeOfParallelism < 1)
                throw new ArgumentOutOfRangeException(
                    "settings",
                    "settings.ParallelOptions.MaxDegreeOfParallelism must be a positive number of instances");
            Contract.EndContractBlock();

            this.Settings = settings;

            // The distributor lambdas read this.observers lazily; it is assigned further down,
            // before any example can be posted.
            switch (settings.ExampleDistribution)
            {
                case VowpalWabbitExampleDistribution.UniformRandom:
                    this.exampleDistributor = _ =>
                    {
                        // Post() may be called concurrently and Random is not thread-safe
                        lock (this.random)
                        {
                            return this.random.Next(this.observers.Length);
                        }
                    };
                    break;
                case VowpalWabbitExampleDistribution.RoundRobin:
                    this.exampleDistributor = localExampleCount => (int)(localExampleCount % this.observers.Length);
                    break;
                default:
                    // fail fast instead of a NullReferenceException on the first Post()
                    throw new ArgumentOutOfRangeException(
                        "settings",
                        "Unsupported example distribution: " + settings.ExampleDistribution);
            }

            this.exampleCount = 0;
            this.syncActions = new ConcurrentList<Action<VowpalWabbit>>();

            var instanceCount = settings.ParallelOptions.MaxDegreeOfParallelism;
            this.vws = new VowpalWabbit[instanceCount];
            this.actionBlocks = new ActionBlock<Action<VowpalWabbit>>[instanceCount];

            // setup AllReduce chain: node 0 is the root, every other node references it
            for (int i = 0; i < instanceCount; i++)
            {
                var nodeSettings = (VowpalWabbitSettings)settings.Clone();
                if (i > 0)
                {
                    nodeSettings.Root = this.vws[0];
                }

                nodeSettings.Node = (uint)i;

                this.vws[i] = new VowpalWabbit(nodeSettings);
                this.actionBlocks[i] = this.CreateActionBlock(this.vws[i]);
            }

            // get observers to allow for blocking calls
            this.observers = this.actionBlocks.Select(ab => ab.AsObserver()).ToArray();

            this.completionTasks = new Task[instanceCount];

            // root: final AllReduce followed by the pending synchronization actions
            var root = this.vws[0];
            this.completionTasks[0] = this.actionBlocks[0].Completion
                .ContinueWith(_ => this.SynchronizeRoot(root));

            for (int i = 1; i < this.vws.Length; i++)
            {
                // closure var; perform final AllReduce
                var vw = this.vws[i];
                this.completionTasks[i] = this.actionBlocks[i].Completion
                    .ContinueWith(_ => vw.EndOfPass(), this.Settings.ParallelOptions.CancellationToken);
            }
        }

        /// <summary>
        /// Creates the single-threaded, bounded work queue that executes actions against <paramref name="vw"/>.
        /// </summary>
        /// <param name="vw">The instance every queued action is invoked with.</param>
        /// <returns>The action block bound to <paramref name="vw"/>.</returns>
        private ActionBlock<Action<VowpalWabbit>> CreateActionBlock(VowpalWabbit vw)
        {
            return new ActionBlock<Action<VowpalWabbit>>(
                action => action(vw),
                new ExecutionDataflowBlockOptions
                {
                    // one action at a time per vw instance
                    MaxDegreeOfParallelism = 1,
                    TaskScheduler = this.Settings.ParallelOptions.TaskScheduler,
                    CancellationToken = this.Settings.ParallelOptions.CancellationToken,
                    BoundedCapacity = (int)this.Settings.MaxExampleQueueLengthPerInstance
                });
        }

        /// <summary>
        /// Performs AllReduce on the root instance and then runs (and removes) all pending synchronization actions.
        /// </summary>
        /// <param name="vw">The root instance.</param>
        private void SynchronizeRoot(VowpalWabbit vw)
        {
            // perform AllReduce
            vw.EndOfPass();

            // execute synchronization actions
            foreach (var syncAction in this.syncActions.RemoveAll())
            {
                syncAction(vw);
            }
        }

        /// <summary>
        /// VowpalWabbit instances participating in AllReduce.
        /// </summary>
        public VowpalWabbit[] VowpalWabbits
        {
            get { return this.vws; }
        }

        /// <summary>
        /// Creates a new instance of <see cref="VowpalWabbitAsync{TExample}"/> to feed examples of type <typeparamref name="TExample"/>.
        /// </summary>
        /// <typeparam name="TExample">The user example type.</typeparam>
        /// <returns>A new instance of <see cref="VowpalWabbitAsync{TExample}"/>.</returns>
        public VowpalWabbitAsync<TExample> Create<TExample>()
        {
            return new VowpalWabbitAsync<TExample>(this);
        }

        /// <summary>
        /// Creates a new instance of <see cref="VowpalWabbitAsync{TExample,TActionDependentFeature}"/> to feed multi-line
        /// examples of type <typeparamref name="TExample"/> and <typeparamref name="TActionDependentFeature"/>.
        /// </summary>
        /// <typeparam name="TExample">The user example type.</typeparam>
        /// <typeparam name="TActionDependentFeature">The user action dependent feature type.</typeparam>
        /// <returns>A new instance of <see cref="VowpalWabbitAsync{TExample,TActionDependentFeature}"/>.</returns>
        public VowpalWabbitAsync<TExample, TActionDependentFeature> Create<TExample, TActionDependentFeature>()
        {
            return new VowpalWabbitAsync<TExample, TActionDependentFeature>(this);
        }

        /// <summary>
        /// Everytime <see cref="VowpalWabbitSettings.ExampleCountPerRun"/> examples have been enqueued,
        /// an AllReduce-sync operation (<see cref="VowpalWabbit.EndOfPass"/>) is injected into every queue.
        /// </summary>
        /// <returns>The number of examples enqueued so far.</returns>
        private uint CheckEndOfPass()
        {
            var enqueued = (uint)Interlocked.Increment(ref this.exampleCount);

            // since there is no lock on the input queue, it's not guaranteed that exactly
            // that number of examples are processed (but maybe a few more).
            if (enqueued % this.Settings.ExampleCountPerRun == 0)
            {
                // root: AllReduce + synchronization actions
                this.observers[0].OnNext(this.SynchronizeRoot);

                for (int i = 1; i < this.observers.Length; i++)
                {
                    // perform AllReduce
                    this.observers[i].OnNext(vw => vw.EndOfPass());
                }
            }

            return enqueued;
        }

        /// <summary>
        /// Enqueues an action to be executed on one of vw instances.
        /// </summary>
        /// <param name="action">The action to be executed (e.g. Learn/Predict/...).</param>
        /// <remarks>
        /// If number of actions waiting to be executed has reached
        /// <see cref="VowpalWabbitSettings.MaxExampleQueueLengthPerInstance"/> this method blocks.
        /// </remarks>
        public void Post(Action<VowpalWabbit> action)
        {
            Contract.Requires(action != null);

            var enqueued = this.CheckEndOfPass();

            // dispatch
            this.observers[this.exampleDistributor(enqueued)].OnNext(action);
        }

        /// <summary>
        /// Enqueues a task to be executed by single VowpalWabbit instance.
        /// </summary>
        /// <remarks>
        /// Which VowpalWabbit instance is chosen, is determined by <see cref="VowpalWabbitSettings.ExampleDistribution"/>.
        /// </remarks>
        /// <typeparam name="T">The return type of the task.</typeparam>
        /// <param name="func">The task to be executed.</param>
        /// <returns>
        /// The awaitable result of the supplied task; an exception thrown by <paramref name="func"/> faults the task.
        /// </returns>
        internal Task<T> Post<T>(Func<VowpalWabbit, T> func)
        {
            Contract.Requires(func != null);

            var enqueued = this.CheckEndOfPass();
            var completionSource = new TaskCompletionSource<T>();

            // dispatch to a Vowpal Wabbit instance
            this.observers[this.exampleDistributor(enqueued)].OnNext(vw =>
            {
                try
                {
                    completionSource.SetResult(func(vw));
                }
                catch (Exception ex)
                {
                    completionSource.SetException(ex);
                }
            });

            return completionSource.Task;
        }

        /// <summary>
        /// Learns from the given example.
        /// </summary>
        /// <param name="line">The example to learn.</param>
        public void Learn(string line)
        {
            Contract.Requires(line != null);

            this.Post(vw => vw.Learn(line));
        }

        /// <summary>
        /// Learns from the given example.
        /// </summary>
        /// <param name="lines">The multi-line example to learn.</param>
        public void Learn(IEnumerable<string> lines)
        {
            Contract.Requires(lines != null);

            this.Post(vw => vw.Learn(lines));
        }

        /// <summary>
        /// Synchronized performance statistics.
        /// </summary>
        /// <remarks>
        /// The task is only completed after the next synchronization of all instances, i.e. when the root
        /// instance runs its pending synchronization actions. Every read of this property registers a new
        /// synchronization action.
        /// </remarks>
        /// <exception cref="InvalidOperationException">Thrown if <see cref="Complete"/> was called before.</exception>
        public Task<VowpalWabbitPerformanceStatistics> PerformanceStatistics
        {
            get
            {
                var completionSource = new TaskCompletionSource<VowpalWabbitPerformanceStatistics>();

                this.syncActions.Add(vw =>
                {
                    // report failures through the task so the waiter is released and
                    // the remaining synchronization actions still run
                    try
                    {
                        completionSource.SetResult(vw.PerformanceStatistics);
                    }
                    catch (Exception ex)
                    {
                        completionSource.SetException(ex);
                    }
                });

                return completionSource.Task;
            }
        }

        /// <summary>
        /// Signal that no more examples are sent.
        /// </summary>
        /// <returns>Task completes once the learning and cleanup is done.</returns>
        public Task Complete()
        {
            // make sure no more sync actions are added, which might otherwise never be called
            this.syncActions.CompleteAdding();

            foreach (var actionBlock in this.actionBlocks)
            {
                actionBlock.Complete();
            }

            return Task.WhenAll(this.completionTasks);
        }

        /// <summary>
        /// Saves a model as part of the synchronization.
        /// </summary>
        /// <returns>Task completes once the model is saved; it faults if saving throws.</returns>
        /// <exception cref="InvalidOperationException">Thrown if <see cref="Complete"/> was called before.</exception>
        public Task SaveModel()
        {
            var completionSource = new TaskCompletionSource<bool>();

            this.syncActions.Add(vw =>
            {
                try
                {
                    vw.SaveModel();
                    completionSource.SetResult(true);
                }
                catch (Exception ex)
                {
                    completionSource.SetException(ex);
                }
            });

            return completionSource.Task;
        }

        /// <summary>
        /// Saves a model as part of the synchronization.
        /// </summary>
        /// <param name="filename">The file the model is saved to.</param>
        /// <returns>Task completes once the model is saved; it faults if saving throws.</returns>
        /// <exception cref="InvalidOperationException">Thrown if <see cref="Complete"/> was called before.</exception>
        public Task SaveModel(string filename)
        {
            Contract.Requires(!string.IsNullOrEmpty(filename));

            var completionSource = new TaskCompletionSource<bool>();

            this.syncActions.Add(vw =>
            {
                try
                {
                    vw.SaveModel(filename);
                    completionSource.SetResult(true);
                }
                catch (Exception ex)
                {
                    completionSource.SetException(ex);
                }
            });

            return completionSource.Task;
        }

        /// <summary>
        /// The settings shared across all instances.
        /// </summary>
        public VowpalWabbitSettings Settings
        {
            get;
            private set;
        }

        /// <summary>
        /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
        /// </summary>
        public void Dispose()
        {
            this.Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Completes all queues, waits for the workers to finish and disposes the vw instances.
        /// </summary>
        /// <param name="disposing">True if called from <see cref="Dispose()"/>; nothing is done otherwise.</param>
        private void Dispose(bool disposing)
        {
            if (disposing)
            {
                if (this.completionTasks != null)
                {
                    // mark completion and wait for all action blocks (incl. their final AllReduce) to finish;
                    // Complete() already returns a task covering every completion task
                    this.Complete()
                        .Wait(this.Settings.ParallelOptions.CancellationToken);

                    this.completionTasks = null;
                }

                if (this.vws != null)
                {
                    foreach (var vw in this.vws)
                    {
                        vw.Dispose();
                    }

                    this.vws = null;
                }
            }
        }

        /// <summary>
        /// Thread-safe list implementation supporting completion.
        /// </summary>
        /// <typeparam name="T">The element type.</typeparam>
        private class ConcurrentList<T>
        {
            private bool completed = false;

            private readonly List<T> items = new List<T>();

            private readonly object lockObject = new object();

            /// <summary>
            /// Adds an object to the end of the list.
            /// </summary>
            /// <param name="item">The object to be added to the list.</param>
            /// <exception cref="InvalidOperationException">
            /// Thrown if <see cref="CompleteAdding"/> was called previously.
            /// </exception>
            public void Add(T item)
            {
                lock (this.lockObject)
                {
                    if (this.completed)
                    {
                        throw new InvalidOperationException("ConcurrentList has been marked completed.");
                    }

                    this.items.Add(item);
                }
            }

            /// <summary>
            /// Marks this list as complete. Any subsequent calls to <see cref="Add"/> will trigger an
            /// <see cref="InvalidOperationException"/>.
            /// </summary>
            public void CompleteAdding()
            {
                lock (this.lockObject)
                {
                    this.completed = true;
                }
            }

            /// <summary>
            /// Removes all elements from the list.
            /// </summary>
            /// <returns>The elements removed.</returns>
            public T[] RemoveAll()
            {
                lock (this.lockObject)
                {
                    var ret = this.items.ToArray();
                    this.items.Clear();
                    return ret;
                }
            }
        }
    }
}