// This file was auto-generated by ML.NET Model Builder.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Mask_MLML.Model;
using Microsoft.ML.Vision;

namespace Mask_MLML.ConsoleApp
{
    public static class ModelBuilder
    {
        private static string TRAIN_DATA_FILEPATH = @"C:\Users\HUAWEI\AppData\Local\Temp\b441e20c-6f79-4437-a21a-be0014456e13.tsv";
        private static string MODEL_FILEPATH = @"C:\Users\HUAWEI\AppData\Local\Temp\MLVSTools\Mask.MLML\Mask.MLML.Model\MLModel.zip";

        // Create MLContext to be shared across the model creation workflow objects.
        // Set a random seed for repeatable/deterministic results across multiple trainings.
        private static MLContext mlContext = new MLContext(seed: 1);

        public static void CreateModel()
        {
            // Load data
            IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
                                            path: TRAIN_DATA_FILEPATH,
                                            hasHeader: true,
                                            separatorChar: '\t',
                                            allowQuoting: true,
                                            allowSparse: false);

            // Build training pipeline
            IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

            // Train model
            ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);

            // Evaluate quality of the model
            Evaluate(mlContext, trainingDataView, trainingPipeline);

            // Save model
            SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
        }

        public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
        {
            // Data process configuration with pipeline data transformations:
            // map the string label to a key, load the raw image bytes from the file paths
            // in the ImageSource column, and expose them as the Features column.
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Label")
                                      .Append(mlContext.Transforms.LoadRawImageBytes("ImageSource_featurized", null, "ImageSource"))
                                      .Append(mlContext.Transforms.CopyColumns("Features", "ImageSource_featurized"));

            // Set the training algorithm: the image classification trainer, followed by a
            // transform that maps the predicted key back to its original label value.
            var trainer = mlContext.MulticlassClassification.Trainers.ImageClassification(
                                          new ImageClassificationTrainer.Options()
                                          {
                                              LabelColumnName = "Label",
                                              FeatureColumnName = "Features"
                                          })
                                      .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

            var trainingPipeline = dataProcessPipeline.Append(trainer);

            return trainingPipeline;
        }

        public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            Console.WriteLine("=============== Training model ===============");

            ITransformer model = trainingPipeline.Fit(trainingDataView);

            Console.WriteLine("=============== End of training process ===============");
            return model;
        }

        private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            // Cross-validate with a single dataset (since we don't have two datasets, one for training and one for evaluation)
            // in order to evaluate the model and get its accuracy metrics.
            Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
            var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "Label");
            PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults);
        }

        private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
        {
            // Save/persist the trained model to a .zip file.
            Console.WriteLine("=============== Saving the model ===============");
            mlContext.Model.Save(mlModel, modelInputSchema, GetAbsolutePath(modelRelativePath));
            Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath));
        }

        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;

            string fullPath = Path.Combine(assemblyFolderPath, relativePath);
            return fullPath;
        }

        public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"************************************************************");
            Console.WriteLine($"*    Metrics for multi-class classification model   ");
            Console.WriteLine($"*-----------------------------------------------------------");
            Console.WriteLine($"    MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
            Console.WriteLine($"    MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
            Console.WriteLine($"    LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
            for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
            {
                Console.WriteLine($"    LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
            }
            Console.WriteLine($"************************************************************");
        }

        public static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
        {
            var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics);

            var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy);
            var microAccuracyAverage = microAccuracyValues.Average();
            var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
            var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

            var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy);
            var macroAccuracyAverage = macroAccuracyValues.Average();
            var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
            var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

            var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss);
            var logLossAverage = logLossValues.Average();
            var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
            var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

            var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);
            var logLossReductionAverage = logLossReductionValues.Average();
            var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
            var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

            Console.WriteLine($"*************************************************************************************************************");
            Console.WriteLine($"*       Metrics for Multi-class Classification model      ");
            Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
            Console.WriteLine($"*       Average MicroAccuracy:    {microAccuracyAverage:0.###}  - Standard deviation: ({microAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
            Console.WriteLine($"*       Average MacroAccuracy:    {macroAccuracyAverage:0.###}  - Standard deviation: ({macroAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
            Console.WriteLine($"*       Average LogLoss:          {logLossAverage:#.###}  - Standard deviation: ({logLossStdDeviation:#.###})  - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
            Console.WriteLine($"*       Average LogLossReduction: {logLossReductionAverage:#.###}  - Standard deviation: ({logLossReductionStdDeviation:#.###})  - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
            Console.WriteLine($"*************************************************************************************************************");
        }

        public static double CalculateStandardDeviation(IEnumerable<double> values)
        {
            double average = values.Average();
            double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum();
            double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1));
            return standardDeviation;
        }

        public static double CalculateConfidenceInterval95(IEnumerable<double> values)
        {
            double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1));
            return confidenceInterval95;
        }
    }
}
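
// --- Usage sketch (not part of the generated file): consuming the saved model ---
// A minimal example of how the model persisted by SaveModel above might be loaded and used
// for a single prediction. It assumes the Model Builder-generated ModelInput/ModelOutput
// classes in Mask_MLML.Model expose an ImageSource column (image file path) and a
// Prediction property on the output; the class and method names below (ConsumeModelExample,
// PredictSingleImage) are hypothetical. The Microsoft.ML APIs used (Model.Load,
// Model.CreatePredictionEngine) are standard.
namespace Mask_MLML.ConsoleApp
{
    using System;
    using Microsoft.ML;
    using Mask_MLML.Model;

    public static class ConsumeModelExample
    {
        public static void PredictSingleImage(string modelPath, string imagePath)
        {
            var mlContext = new MLContext();

            // Load the trained model saved by ModelBuilder.SaveModel.
            ITransformer mlModel = mlContext.Model.Load(modelPath, out DataViewSchema inputSchema);

            // Create a prediction engine for single, in-memory predictions.
            var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(mlModel);

            // Score one image; ImageSource is assumed to hold the image file path,
            // matching the column used in the training pipeline above.
            var input = new ModelInput { ImageSource = imagePath };
            ModelOutput result = predEngine.Predict(input);

            Console.WriteLine($"Predicted label: {result.Prediction}");
        }
    }
}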