using System;
using System.Globalization;
using System.IO;
using System.Linq;
using Antlr4.Runtime;
using Antlr4.Runtime.Atn;
using Antlr4.Runtime.Tree;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using VW;
using VW.Serializer;

namespace cs_unittest
{
    /// <summary>
    /// Shared helpers for the unit tests: parsing VW-formatted input through the
    /// ANTLR-generated lexer/parser, driving learn/predict runs from an input file and
    /// comparing performance statistics against values read from a reference stderr file.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the generic argument lists (<c>&lt;TData, TListener&gt;</c>,
    /// <c>VowpalWabbit&lt;TData&gt;</c>, <c>VowpalWabbitExampleValidator&lt;TData&gt;</c>,
    /// <c>VowpalWabbitListenerToEvents&lt;TData&gt;</c>) were missing from the reviewed text
    /// even though the <c>where</c> clauses require them. They are reconstructed from how the
    /// values are used below - confirm against the VW API.
    /// </remarks>
    internal static class VWTestHelper
    {
        /// <summary>Parses <paramref name="text"/> and reports parse events to <paramref name="listener"/>.</summary>
        /// <param name="text">The complete input held in memory.</param>
        /// <param name="listener">Listener attached to the parser as a parse listener.</param>
        internal static void ParseInput(string text, IParseTreeListener listener)
        {
            ParseInput(new AntlrInputStream(text), listener);
        }

        /// <summary>Parses <paramref name="stream"/> through an unbuffered character stream.</summary>
        /// <param name="stream">The input stream; the caller keeps ownership and must dispose it.</param>
        /// <param name="listener">Listener attached to the parser as a parse listener.</param>
        internal static void ParseInput(Stream stream, IParseTreeListener listener)
        {
            ParseInput(new UnbufferedCharStream(stream), listener);
        }

        /// <summary>
        /// Runs the VW lexer and parser over <paramref name="stream"/>, invoking the
        /// <c>start</c> rule with <paramref name="listener"/> and a
        /// <c>TestErrorListener</c> attached.
        /// </summary>
        /// <param name="stream">The character stream to tokenize.</param>
        /// <param name="listener">Listener attached to the parser as a parse listener.</param>
        internal static void ParseInput(ICharStream stream, IParseTreeListener listener)
        {
            // optimized for memory consumption
            var lexer = new VowpalWabbitLexer(stream)
            {
                // Unbuffered streams require tokens to carry their own copy of the text.
                TokenFactory = new CommonTokenFactory(copyText: true)
            };
            var tokens = new UnbufferedTokenStream(lexer);

            // Note: BuildParseTree is intentionally left enabled (do not set it to false),
            // as it is required to access the line.
            var parser = new VowpalWabbitParser(tokens);

            // faster than LL(*)
            parser.Interpreter.PredictionMode = PredictionMode.Sll;

            parser.AddParseListener(listener);
            parser.AddErrorListener(new TestErrorListener());
            parser.start();
        }

        /// <summary>
        /// Learns every example parsed from <paramref name="inputFile"/> and then compares the
        /// resulting performance statistics with those recorded in <paramref name="stderrFile"/>.
        /// </summary>
        /// <typeparam name="TData">Example type produced by the listener.</typeparam>
        /// <typeparam name="TListener">Listener that turns parse events into examples.</typeparam>
        /// <param name="args">Command line arguments used to create the VW instance and the validator.</param>
        /// <param name="inputFile">Path of the VW-formatted input file.</param>
        /// <param name="stderrFile">Path of the reference stderr output holding the expected statistics.</param>
        internal static void Learn<TData, TListener>(string args, string inputFile, string stderrFile)
            where TListener : VowpalWabbitListenerToEvents<TData>, new()
        {
            using (var vw = new VowpalWabbit<TData>(args))
            using (var validate = new VowpalWabbitExampleValidator<TData>(args))
            {
                var listener = new TListener();
                listener.Created = (line, data, label) =>
                {
                    if (data == null)
                    {
                        Assert.Fail("got empty example");
                    }

                    validate.Validate(line, data, label);
                    vw.Learn(data, label);
                };

                // Dispose the file handle once parsing is done; it was previously leaked.
                using (var input = File.OpenRead(inputFile))
                {
                    ParseInput(input, listener);
                }

                AssertEqual(stderrFile, vw.Native.PerformanceStatistics);
            }
        }

        /// <summary>
        /// Predicts every example parsed from <paramref name="inputFile"/> both from its raw line
        /// and from the parsed example, and asserts that the results agree (and, when
        /// <paramref name="referenceFile"/> is given, that they match the reference values).
        /// </summary>
        /// <typeparam name="TData">Example type produced by the listener.</typeparam>
        /// <typeparam name="TListener">Listener that turns parse events into examples.</typeparam>
        /// <param name="args">Command line arguments used to create the VW instances.</param>
        /// <param name="inputFile">Path of the VW-formatted input file.</param>
        /// <param name="referenceFile">
        /// Optional file with one expected prediction per line; only the text before the first
        /// space of each line is parsed.
        /// </param>
        internal static void Predict<TData, TListener>(string args, string inputFile, string referenceFile = null)
            where TData : BaseData
            where TListener : VowpalWabbitListenerToEvents<TData>, new()
        {
            float[] references = null;
            var index = 0;

            if (referenceFile != null)
            {
                references = File.ReadAllLines(referenceFile)
                    .Select(l => float.Parse(l.Split(' ')[0], CultureInfo.InvariantCulture))
                    .ToArray();
            }

            // NOTE(review): vwValidate is never referenced below - confirm whether it can be removed.
            using (var vwRef = new VowpalWabbit(args))
            using (var vwModel = new VowpalWabbitModel(args))
            using (var vwValidate = new VowpalWabbit(args))
            using (var vwInMemoryShared2 = new VowpalWabbit<TData>(new VowpalWabbitSettings { Model = vwModel }))
            using (var validate = new VowpalWabbitExampleValidator<TData>(args))
            {
                var listener = new TListener();
                listener.Created = (line, x, label) =>
                {
                    validate.Validate(line, x, label);

                    var expectedDynamic = vwRef.Predict(x.Line, VowpalWabbitPredictionType.Dynamic);
                    Assert.IsInstanceOfType(expectedDynamic, typeof(float));

                    var expected = vwRef.Predict(x.Line, VowpalWabbitPredictionType.Scalar);
                    var actual = vwInMemoryShared2.Predict(x, VowpalWabbitPredictionType.Scalar, label);

                    Assert.AreEqual((float)expectedDynamic, actual, 1e-5);
                    Assert.AreEqual(expected, actual, 1e-5);

                    if (references != null)
                    {
                        Assert.AreEqual(references[index++], actual, 1e-5);
                    }
                };

                // The input was previously never parsed, so the callback above (and every
                // assertion in it) never ran and inputFile was unused.
                using (var input = File.OpenRead(inputFile))
                {
                    ParseInput(input, listener);
                }
            }
        }

        /// <summary>
        /// Reads the expected statistics from <paramref name="expectedFile"/> and compares them
        /// with <paramref name="actual"/>.
        /// </summary>
        /// <param name="expectedFile">Path of the reference stderr output.</param>
        /// <param name="actual">Statistics produced by the run under test.</param>
        internal static void AssertEqual(string expectedFile, VowpalWabbitPerformanceStatistics actual)
        {
            var expectedPerformanceStatistics = ReadPerformanceStatistics(expectedFile);
            AssertEqual(expectedPerformanceStatistics, actual);
        }

        /// <summary>
        /// Fails the test when <paramref name="expected"/> and <paramref name="actual"/> differ by
        /// more than <paramref name="epsilon"/>; the rule is taken from test/RunTests.
        /// </summary>
        /// <remarks>
        /// A difference above <paramref name="epsilon"/> fails immediately when
        /// |<paramref name="actual"/>| &lt;= 1.0. For larger magnitudes the values are compared
        /// by ratio, failing only when |expected / actual - 1| also exceeds
        /// <paramref name="epsilon"/>.
        /// </remarks>
        /// <param name="expected">Expected value; <c>null</c> skips the comparison.</param>
        /// <param name="actual">Observed value.</param>
        /// <param name="epsilon">Tolerance for both the absolute and the ratio comparison.</param>
        /// <param name="message">Text appended to the failure message.</param>
        internal static void FuzzyEqual(double? expected, double actual, double epsilon, string message)
        {
            if (expected == null)
            {
                return;
            }

            var delta = Math.Abs(expected.Value - actual);
            if (delta <= epsilon)
            {
                return;
            }

            // The difference is 'big enough', but big numbers should be compared by ratio
            // rather than by difference. For numbers close to zero the absolute difference
            // alone decides (this also avoids dividing by zero).
            if (Math.Abs(actual) <= 1.0)
            {
                Assert.Fail(string.Format("{0} vs {1}: delta={2} > Epsilon={3}: {4}",
                    expected, actual, delta, epsilon, message));
            }

            // Safe to divide here because |actual| > 1.0.
            var ratioDelta = Math.Abs(expected.Value / actual - 1.0);
            if (ratioDelta > epsilon)
            {
                // Report the ratio delta, since that is the criterion that failed.
                Assert.Fail(string.Format("{0} vs {1}: ratio delta={2} > Epsilon={3}: {4}",
                    expected, actual, ratioDelta, epsilon, message));
            }
        }

        /// <summary>
        /// Compares two sets of performance statistics. A differing total feature count only
        /// produces a warning on stderr; the number of examples per pass must match exactly and
        /// the remaining values are compared with <see cref="FuzzyEqual"/> at a tolerance of 1e-3.
        /// </summary>
        /// <param name="expected">Expected statistics.</param>
        /// <param name="actual">Statistics produced by the run under test.</param>
        internal static void AssertEqual(VowpalWabbitPerformanceStatistics expected, VowpalWabbitPerformanceStatistics actual)
        {
            if (expected.TotalNumberOfFeatures != actual.TotalNumberOfFeatures)
            {
                Console.Error.WriteLine(
                    "Warning: total number of features differs. Expected: {0} vs. actual: {1}",
                    expected.TotalNumberOfFeatures,
                    actual.TotalNumberOfFeatures);
            }

            Assert.AreEqual(expected.NumberOfExamplesPerPass, actual.NumberOfExamplesPerPass, "NumberOfExamplesPerPass");

            FuzzyEqual(expected.AverageLoss, actual.AverageLoss, 1e-3, "AverageLoss");
            FuzzyEqual(expected.BestConstant, actual.BestConstant, 1e-3, "BestConstant");

            // TODO: something weird is happening here. BestConstantLoss is 0 when using RunAll
            // but has the proper value if just the unit test is run, so it is not compared.
            FuzzyEqual(expected.WeightedExampleSum, actual.WeightedExampleSum, 1e-3, "WeightedExampleSum");
            FuzzyEqual(expected.WeightedLabelSum, actual.WeightedLabelSum, 1e-3, "WeightedLabelSum");
        }

        /// <summary>
        /// Compares statistics read from a reference stderr file with the actual ones. A
        /// differing total feature count only produces a warning on stderr; the number of
        /// examples per pass is checked only when the reference has one, and the remaining
        /// values are compared with <see cref="FuzzyEqual"/> at a tolerance of 1e-3.
        /// </summary>
        /// <param name="expected">Statistics parsed from the reference file.</param>
        /// <param name="actual">Statistics produced by the run under test.</param>
        internal static void AssertEqual(VowpalWabbitStdErrPerformanceStatistics expected, VowpalWabbitPerformanceStatistics actual)
        {
            if (expected.TotalNumberOfFeatures != actual.TotalNumberOfFeatures)
            {
                Console.Error.WriteLine(
                    "Warning: total number of features differs. Expected: {0} vs. actual: {1}",
                    expected.TotalNumberOfFeatures,
                    actual.TotalNumberOfFeatures);
            }

            if (expected.NumberOfExamplesPerPass != null)
            {
                Assert.AreEqual(expected.NumberOfExamplesPerPass, actual.NumberOfExamplesPerPass, "NumberOfExamplesPerPass");
            }

            FuzzyEqual(expected.AverageLoss, actual.AverageLoss, 1e-3, "AverageLoss");
            FuzzyEqual(expected.BestConstant, actual.BestConstant, 1e-3, "BestConstant");

            // TODO: something weird is happening here. BestConstantLoss is 0 when using RunAll
            // but has the proper value if just the unit test is run, so it is not compared.
            FuzzyEqual(expected.WeightedExampleSum, actual.WeightedExampleSum, 1e-3, "WeightedExampleSum");
            FuzzyEqual(expected.WeightedLabelSum, actual.WeightedLabelSum, 1e-3, "WeightedLabelSum");
        }

        /// <summary>
        /// Extracts the performance statistics from the lines of a reference stderr file.
        /// </summary>
        /// <param name="filename">Path of the reference stderr output.</param>
        /// <returns>
        /// The parsed statistics; an entry is <c>null</c> when its label is missing or its value
        /// cannot be parsed.
        /// </returns>
        internal static VowpalWabbitStdErrPerformanceStatistics ReadPerformanceStatistics(string filename)
        {
            var lines = File.ReadAllLines(filename);

            var numExamples = FindULongEntry(lines, "number of examples per pass = ");

            // Fall back to the alternative label when the per-pass entry is absent (null) or 0.
            // A plain "== 0" comparison is false for null, which used to skip the fallback.
            if (numExamples.GetValueOrDefault() == 0)
            {
                numExamples = FindULongEntry(lines, "number of examples = ");
            }

            var stats = new VowpalWabbitStdErrPerformanceStatistics()
            {
                NumberOfExamplesPerPass = numExamples,
                TotalNumberOfFeatures = FindULongEntry(lines, "total feature number = "),
                AverageLoss = FindAverageLossEntry(lines),
                BestConstant = FindDoubleEntry(lines, "best constant = "),
                BestConstantLoss = FindDoubleEntry(lines, "best constant's loss = "),
                WeightedExampleSum = FindDoubleEntry(lines, "weighted example sum = "),
                WeightedLabelSum = FindDoubleEntry(lines, "weighted label sum = ")
            };

            return stats;
        }

        /// <summary>
        /// Returns the text following <paramref name="label"/> on the first line that starts
        /// with it (ordinal comparison), or <c>null</c> when no line does.
        /// </summary>
        private static string FindEntryText(string[] lines, string label)
        {
            var candidate = lines.FirstOrDefault(l => l.StartsWith(label, StringComparison.Ordinal));
            return candidate == null ? null : candidate.Substring(label.Length);
        }

        /// <summary>
        /// Parses the "average loss = " entry, ignoring a trailing " h" marker when present.
        /// </summary>
        /// <returns>The parsed value, or <c>null</c> when the entry is missing or not a number.</returns>
        private static double? FindAverageLossEntry(string[] lines)
        {
            var text = FindEntryText(lines, "average loss = ");
            if (text == null)
            {
                return null;
            }

            if (text.EndsWith(" h", StringComparison.Ordinal))
            {
                text = text.Substring(0, text.Length - 2);
            }

            double ret;
            if (double.TryParse(text, NumberStyles.Float, CultureInfo.InvariantCulture, out ret))
            {
                return ret;
            }

            return null;
        }

        /// <summary>Parses the floating point value following <paramref name="label"/>.</summary>
        /// <returns>The parsed value, or <c>null</c> when the entry is missing or not a number.</returns>
        private static double? FindDoubleEntry(string[] lines, string label)
        {
            var text = FindEntryText(lines, label);
            if (text == null)
            {
                return null;
            }

            double ret;
            if (double.TryParse(text, NumberStyles.Float, CultureInfo.InvariantCulture, out ret))
            {
                return ret;
            }

            return null;
        }

        /// <summary>Parses the unsigned integer value following <paramref name="label"/>.</summary>
        /// <returns>The parsed value, or <c>null</c> when the entry is missing or not a number.</returns>
        private static ulong? FindULongEntry(string[] lines, string label)
        {
            var text = FindEntryText(lines, label);
            if (text == null)
            {
                return null;
            }

            // NumberStyles.Float is kept from the original: it tolerates surrounding
            // whitespace and exponent notation in the reference file.
            ulong ret;
            if (ulong.TryParse(text, NumberStyles.Float, CultureInfo.InvariantCulture, out ret))
            {
                return ret;
            }

            return null;
        }
    }
}