163 lines
9.5 KiB
C#
163 lines
9.5 KiB
C#
// This file was auto-generated by ML.NET Model Builder.
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.IO;
|
|
using System.Linq;
|
|
using Microsoft.ML;
|
|
using Microsoft.ML.Data;
|
|
using Mask_MLML.Model;
|
|
using Microsoft.ML.Vision;
|
|
|
|
namespace Mask_MLML.ConsoleApp
|
|
{
|
|
public static class ModelBuilder
{
    // Both paths are combined with AppDomain.CurrentDomain.BaseDirectory via Path.Combine, so they
    // must be relative. A leading "\" makes the second argument rooted; Path.Combine then returns it
    // unchanged and silently discards the base directory (on Windows "\file" resolves against the
    // current drive's root, on Linux the backslash becomes part of the file name).
    // NOTE(review): Model Builder normally emits absolute paths here; the leading "\" looks like the
    // remnant of a stripped absolute path - confirm the intended data/model location.
    private const string TRAIN_DATA_FILEPATH = "b441e20c-6f79-4437-a21a-be0014456e13.tsv";

    private const string MODEL_FILEPATH = "MLModel.zip";

    // Create MLContext to be shared across the model creation workflow objects.
    // Set a random seed for repeatable/deterministic results across multiple trainings.
    private static readonly MLContext mlContext = new MLContext(seed: 1);

    /// <summary>
    /// End-to-end workflow: loads the tab-separated training data, builds and trains the pipeline,
    /// prints cross-validated metrics, and saves the trained model under the application base directory.
    /// </summary>
    public static void CreateModel()
    {
        string baseDirectory = AppDomain.CurrentDomain.BaseDirectory;

        // Load Data
        IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
            path: Path.Combine(baseDirectory, TRAIN_DATA_FILEPATH),
            hasHeader: true,
            separatorChar: '\t',
            allowQuoting: true,
            allowSparse: false);

        // Build training pipeline
        IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

        // Train Model
        ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);

        // Evaluate quality of Model (cross-validation fits the pipeline again on each fold)
        Evaluate(mlContext, trainingDataView, trainingPipeline);

        // Save model
        SaveModel(mlContext, mlModel, Path.Combine(baseDirectory, MODEL_FILEPATH), trainingDataView.Schema);
    }

    /// <summary>
    /// Builds the (untrained) image-classification pipeline.
    /// </summary>
    /// <param name="mlContext">Context used to create the transforms and the trainer.</param>
    /// <returns>
    /// An estimator that maps "Label" to a key, loads raw image bytes from the "ImageSource" column into
    /// "Features", trains an ImageClassification model and maps "PredictedLabel" back to the label values.
    /// </returns>
    public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
    {
        // Data process configuration with pipeline data transformations.
        // imageFolder is null, so the ImageSource values are used as-is as file paths
        // (presumably absolute paths written by Model Builder - verify against the .tsv).
        var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Label")
            .Append(mlContext.Transforms.LoadRawImageBytes("ImageSource_featurized", null, "ImageSource"))
            .Append(mlContext.Transforms.CopyColumns("Features", "ImageSource_featurized"));

        // Set the training algorithm; map the predicted key back to the original label value.
        var trainer = mlContext.MulticlassClassification.Trainers.ImageClassification(new ImageClassificationTrainer.Options() { LabelColumnName = "Label", FeatureColumnName = "Features" })
            .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

        var trainingPipeline = dataProcessPipeline.Append(trainer);

        return trainingPipeline;
    }

    /// <summary>
    /// Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>, logging start and end to the console.
    /// </summary>
    /// <returns>The trained transformer.</returns>
    public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    {
        Console.WriteLine("=============== Training model ===============");

        ITransformer model = trainingPipeline.Fit(trainingDataView);

        Console.WriteLine("=============== End of training process ===============");
        return model;
    }

    /// <summary>
    /// Runs 5-fold cross-validation of the pipeline on the single available dataset and prints the
    /// averaged metrics.
    /// </summary>
    private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
    {
        // Cross-validate with a single dataset (there are not two separate datasets, one for training
        // and one for evaluation) in order to evaluate and get the model's accuracy metrics.
        Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
        var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "Label");
        PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults);
    }

    /// <summary>
    /// Persists the trained model (with its input schema) to a .zip file.
    /// </summary>
    /// <param name="modelRelativePath">
    /// Path resolved through <see cref="GetAbsolutePath"/>; an absolute path is used unchanged.
    /// </param>
    private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
    {
        // Save/persist the trained model to a .ZIP file
        Console.WriteLine($"=============== Saving the model ===============");

        // Resolve once so the saved location and the reported location cannot diverge.
        string modelPath = GetAbsolutePath(modelRelativePath);
        mlContext.Model.Save(mlModel, modelInputSchema, modelPath);
        Console.WriteLine("The model is saved to {0}", modelPath);
    }

    /// <summary>
    /// Resolves <paramref name="relativePath"/> against the folder containing the Program assembly.
    /// </summary>
    /// <returns>
    /// The combined path. If <paramref name="relativePath"/> is rooted, Path.Combine returns it unchanged.
    /// </returns>
    public static string GetAbsolutePath(string relativePath)
    {
        string assemblyLocation = typeof(Program).Assembly.Location;

        // Assembly.Location is empty for single-file / in-memory assemblies, where new FileInfo("")
        // would throw; fall back to the application base directory in that case.
        string assemblyFolderPath = string.IsNullOrEmpty(assemblyLocation)
            ? AppDomain.CurrentDomain.BaseDirectory
            : new FileInfo(assemblyLocation).Directory.FullName;

        return Path.Combine(assemblyFolderPath, relativePath);
    }

    /// <summary>
    /// Prints the metrics of a single multi-class evaluation, including per-class log-loss
    /// (classes are numbered from 1 in the output).
    /// </summary>
    public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
    {
        Console.WriteLine($"************************************************************");
        Console.WriteLine($"* Metrics for multi-class classification model ");
        Console.WriteLine($"*-----------------------------------------------------------");
        Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
        Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
        for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
        {
            Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
        }
        Console.WriteLine($"************************************************************");
    }

    /// <summary>
    /// Prints the average, sample standard deviation and 95% confidence half-width of
    /// MicroAccuracy, MacroAccuracy, LogLoss and LogLossReduction across cross-validation folds.
    /// </summary>
    public static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
    {
        // Materialize once: each per-metric query below is enumerated several times
        // (average, std-dev, confidence interval).
        var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics).ToList();

        var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy).ToList();
        var microAccuracyAverage = microAccuracyValues.Average();
        var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
        var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

        var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy).ToList();
        var macroAccuracyAverage = macroAccuracyValues.Average();
        var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
        var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

        var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss).ToList();
        var logLossAverage = logLossValues.Average();
        var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
        var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

        var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction).ToList();
        var logLossReductionAverage = logLossReductionValues.Average();
        var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
        var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

        // "0.###" rather than "#.###": the "#" placeholder prints nothing for 0 (e.g. "()")
        // and drops the leading zero of values below 1 (".123").
        Console.WriteLine($"*************************************************************************************************************");
        Console.WriteLine($"* Metrics for Multi-class Classification model ");
        Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAccuraciesStdDeviation:0.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:0.###})");
        Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:0.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:0.###})");
        Console.WriteLine($"* Average LogLoss: {logLossAverage:0.###} - Standard deviation: ({logLossStdDeviation:0.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:0.###})");
        Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:0.###} - Standard deviation: ({logLossReductionStdDeviation:0.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:0.###})");
        Console.WriteLine($"*************************************************************************************************************");
    }

    /// <summary>
    /// Sample (Bessel-corrected, n - 1) standard deviation of <paramref name="values"/>.
    /// </summary>
    /// <returns>The standard deviation, or NaN when there is only one value (undefined for a single sample).</returns>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="values"/> is empty.</exception>
    public static double CalculateStandardDeviation(IEnumerable<double> values)
    {
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }

        // Materialize once so a deferred query is not re-enumerated for the mean, the sum and the count.
        double[] samples = values.ToArray();
        double average = samples.Average(); // throws InvalidOperationException for an empty sequence

        if (samples.Length < 2)
        {
            return double.NaN;
        }

        double sumOfSquaresOfDifferences = samples.Sum(val => (val - average) * (val - average));
        return Math.Sqrt(sumOfSquaresOfDifferences / (samples.Length - 1));
    }

    /// <summary>
    /// Half-width of the normal-approximation 95% confidence interval of the mean:
    /// 1.96 * s / sqrt(n), where s is the sample standard deviation.
    /// </summary>
    /// <returns>The half-width, or NaN when there is only one value.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="values"/> is null.</exception>
    /// <exception cref="InvalidOperationException"><paramref name="values"/> is empty.</exception>
    public static double CalculateConfidenceInterval95(IEnumerable<double> values)
    {
        if (values == null)
        {
            throw new ArgumentNullException(nameof(values));
        }

        double[] samples = values.ToArray();

        // Standard error of the mean is s / sqrt(n). s is already n - 1 corrected, so dividing by
        // sqrt(n - 1) (as this method previously did) applied the correction twice.
        // NOTE(review): with only a handful of folds a Student-t critical value (e.g. 2.776 for 5 folds)
        // would be more accurate than the normal 1.96.
        return 1.96 * CalculateStandardDeviation(samples) / Math.Sqrt(samples.Length);
    }
}
|
|
}
|