图像分类 - 从位图（Bitmap）对象加载（Microsoft.ML）



我使用 Visual Studio 的图形化向导（ML.NET Model Builder）创建了模型，它为我生成了一个项目和相应的代码。一切都能正常工作，唯一的问题是：输入时必须以字符串形式提供图像文件的路径。能否直接把位图（Bitmap）格式的图像传给神经网络？

我找到了很多示例，但它们都和我的代码不一样，好像我用的是另一个版本。我尝试改写找到的代码以适配我的项目，但遇到了各种错误。

请问在目前的 Microsoft.ML 1.5 中如何实现这一点？如何修改下面生成的代码，使其接受位图图像（而不是文件路径）作为输入？

我的 ModelInput.cs：

// This file was auto-generated by ML.NET Model Builder. 
using Microsoft.ML.Data;
namespace MLTestAppML.Model
{
    /// <summary>
    /// One row of the Model Builder training file: a class label and the image it describes.
    /// </summary>
    public class ModelInput
    {
        /// <summary>Class label, read from column 0 of the training file.</summary>
        [ColumnName("Label")]
        [LoadColumn(0)]
        public string Label { get; set; }

        /// <summary>
        /// Image source (a file path, which the training pipeline turns into raw bytes),
        /// read from column 1 of the training file.
        /// </summary>
        [ColumnName("ImageSource")]
        [LoadColumn(1)]
        public string ImageSource { get; set; }
    }
}

我的 ModelBuilder.cs：

// This file was auto-generated by ML.NET Model Builder. 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using MLTestAppML.Model;
using Microsoft.ML.Vision;
namespace MLTestAppML.ConsoleApp
{
    /// <summary>
    /// Trains, cross-validates and saves the image-classification model generated by ML.NET Model Builder.
    /// </summary>
    public static class ModelBuilder
    {
        // Absolute Windows paths. Every separator must be a backslash: without them "C:Users..." is a
        // drive-relative path and does not point at the intended file.
        private static readonly string TRAIN_DATA_FILEPATH = @"C:\Users\aaa\AppData\Local\Temp\e43005d1-d83d-4f35-ab8d-7dbc3e693583.tsv";
        private static readonly string MODEL_FILEPATH = @"C:\Users\aaa\AppData\Local\Temp\MLVSTools\MLTestAppML\MLTestAppML.Model\MLModel.zip";

        // Create MLContext to be shared across the model creation workflow objects.
        // Set a random seed for repeatable/deterministic results across multiple trainings.
        private static readonly MLContext mlContext = new MLContext(seed: 1);

        /// <summary>
        /// Loads the training data, builds and trains the pipeline, cross-validates it and saves the model.
        /// </summary>
        public static void CreateModel()
        {
            // Load data. The training file is tab-separated (.tsv), so the separator is '\t'
            // (a plain 't' would split every row on the letter t).
            IDataView trainingDataView = mlContext.Data.LoadFromTextFile<ModelInput>(
                path: TRAIN_DATA_FILEPATH,
                hasHeader: true,
                separatorChar: '\t',
                allowQuoting: true,
                allowSparse: false);

            // Build training pipeline
            IEstimator<ITransformer> trainingPipeline = BuildTrainingPipeline(mlContext);

            // Train model
            ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline);

            // Evaluate quality of model
            Evaluate(mlContext, trainingDataView, trainingPipeline);

            // Save model
            SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema);
        }

        /// <summary>
        /// Builds the training pipeline. The string label is mapped to a key, and each image's raw bytes
        /// are loaded into "Features". An image-classification trainer is applied, and the predicted key
        /// is mapped back to its label.
        /// </summary>
        /// <param name="mlContext">Context used to create the transforms and the trainer.</param>
        /// <returns>The untrained estimator chain.</returns>
        public static IEstimator<ITransformer> BuildTrainingPipeline(MLContext mlContext)
        {
            // Data process configuration with pipeline data transformations.
            // imageFolder is null, so presumably the ImageSource values are used as complete paths.
            var dataProcessPipeline = mlContext.Transforms.Conversion.MapValueToKey("Label", "Label")
                .Append(mlContext.Transforms.LoadRawImageBytes(
                    outputColumnName: "ImageSource_featurized",
                    imageFolder: null,
                    inputColumnName: "ImageSource"))
                .Append(mlContext.Transforms.CopyColumns("Features", "ImageSource_featurized"));

            // Set the training algorithm
            var trainer = mlContext.MulticlassClassification.Trainers.ImageClassification(
                    new ImageClassificationTrainer.Options() { LabelColumnName = "Label", FeatureColumnName = "Features" })
                .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

            return dataProcessPipeline.Append(trainer);
        }

        /// <summary>Fits <paramref name="trainingPipeline"/> on <paramref name="trainingDataView"/>.</summary>
        /// <returns>The trained model.</returns>
        public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            Console.WriteLine("=============== Training  model ===============");
            ITransformer model = trainingPipeline.Fit(trainingDataView);
            Console.WriteLine("=============== End of training process ===============");
            return model;
        }

        /// <summary>
        /// Runs 5-fold cross-validation on the single available dataset and prints the averaged metrics.
        /// </summary>
        private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator<ITransformer> trainingPipeline)
        {
            // Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
            // in order to evaluate and get the model's accuracy metrics
            Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ===============");
            var crossValidationResults = mlContext.MulticlassClassification.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: 5, labelColumnName: "Label");
            PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults);
        }

        /// <summary>Saves the trained model, with its input schema, to a .zip file.</summary>
        /// <param name="modelRelativePath">Path resolved against the assembly folder by <see cref="GetAbsolutePath"/>.</param>
        private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema)
        {
            // Save/persist the trained model to a .ZIP file
            Console.WriteLine($"=============== Saving the model  ===============");
            string absolutePath = GetAbsolutePath(modelRelativePath);
            mlContext.Model.Save(mlModel, modelInputSchema, absolutePath);
            Console.WriteLine("The model is saved to {0}", absolutePath);
        }

        /// <summary>
        /// Combines <paramref name="relativePath"/> with the folder containing the <c>Program</c> assembly.
        /// An already-absolute path is returned unchanged by <see cref="Path.Combine(string, string)"/>.
        /// </summary>
        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;
            return Path.Combine(assemblyFolderPath, relativePath);
        }

        /// <summary>Prints the metrics of a single multi-class evaluation.</summary>
        public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics)
        {
            Console.WriteLine($"************************************************************");
            Console.WriteLine($"*    Metrics for multi-class classification model   ");
            Console.WriteLine($"*-----------------------------------------------------------");
            Console.WriteLine($"    MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
            Console.WriteLine($"    MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better");
            Console.WriteLine($"    LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better");
            for (int i = 0; i < metrics.PerClassLogLoss.Count; i++)
            {
                Console.WriteLine($"    LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better");
            }
            Console.WriteLine($"************************************************************");
        }

        /// <summary>
        /// Prints the mean, standard deviation and 95% confidence interval of each metric across the folds.
        /// </summary>
        public static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable<TrainCatalogBase.CrossValidationResult<MulticlassClassificationMetrics>> crossValResults)
        {
            // Materialize once: every statistic below enumerates these sequences again.
            var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics).ToList();

            var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy).ToArray();
            var microAccuracyAverage = microAccuracyValues.Average();
            var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues);
            var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues);

            var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy).ToArray();
            var macroAccuracyAverage = macroAccuracyValues.Average();
            var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues);
            var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues);

            var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss).ToArray();
            var logLossAverage = logLossValues.Average();
            var logLossStdDeviation = CalculateStandardDeviation(logLossValues);
            var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues);

            var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction).ToArray();
            var logLossReductionAverage = logLossReductionValues.Average();
            var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues);
            var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues);

            Console.WriteLine($"*************************************************************************************************************");
            Console.WriteLine($"*       Metrics for Multi-class Classification model      ");
            Console.WriteLine($"*------------------------------------------------------------------------------------------------------------");
            Console.WriteLine($"*       Average MicroAccuracy:    {microAccuracyAverage:0.###}  - Standard deviation: ({microAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})");
            Console.WriteLine($"*       Average MacroAccuracy:    {macroAccuracyAverage:0.###}  - Standard deviation: ({macroAccuraciesStdDeviation:#.###})  - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})");
            Console.WriteLine($"*       Average LogLoss:          {logLossAverage:#.###}  - Standard deviation: ({logLossStdDeviation:#.###})  - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})");
            Console.WriteLine($"*       Average LogLossReduction: {logLossReductionAverage:#.###}  - Standard deviation: ({logLossReductionStdDeviation:#.###})  - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})");
            Console.WriteLine($"*************************************************************************************************************");
        }

        /// <summary>
        /// Sample standard deviation (n - 1 denominator) of <paramref name="values"/>.
        /// </summary>
        /// <returns>NaN when only one value is supplied.</returns>
        /// <exception cref="InvalidOperationException">When <paramref name="values"/> is empty.</exception>
        public static double CalculateStandardDeviation(IEnumerable<double> values)
        {
            // Materialize once so a deferred query is not re-evaluated for each pass.
            double[] samples = values.ToArray();
            double average = samples.Average();
            double sumOfSquaresOfDifferences = samples.Select(val => (val - average) * (val - average)).Sum();
            return Math.Sqrt(sumOfSquaresOfDifferences / (samples.Length - 1));
        }

        /// <summary>
        /// Half-width of the normal-approximation 95% confidence interval for the mean of
        /// <paramref name="values"/>: 1.96 * s / sqrt(n).
        /// </summary>
        /// <remarks>
        /// The standard error of the mean is s / sqrt(n); dividing by sqrt(n - 1) would overstate the interval.
        /// With only a handful of folds, a Student-t quantile would be more accurate than 1.96.
        /// </remarks>
        public static double CalculateConfidenceInterval95(IEnumerable<double> values)
        {
            double[] samples = values.ToArray();
            return 1.96 * CalculateStandardDeviation(samples) / Math.Sqrt(samples.Length);
        }
    }
}

已更新

较新版本的 Model Builder 支持直接从 byte[] 进行预测，请确保您安装的是 16.13.10.2241902 版本。

您需要重新训练模型，然后就可以按以下方式运行预测了：

// Load sample data.
// NOTE(review): the backslashes in this path were lost when the snippet was pasted; the folder
// boundaries below are a best-effort reconstruction - point it at your own image file.
// If you start from a System.Drawing.Bitmap rather than a file, one option is to encode it into a
// MemoryStream (bitmap.Save(stream, ImageFormat.Jpeg)) and pass stream.ToArray() instead; presumably
// the model expects encoded image bytes, like the ones read from an image file here - confirm.
var imageBytes = File.ReadAllBytes(@"C:\Users\ivmendoza\Documents\MIL CoPilot Workspaces\AlignersInBag2 Condtions DS Tiny\NonTypical\MX2-RND40K_20220725020628_PIDV_12443583U22N_12443583L22N.JPG");

// In this generated AlignersCounterModel, ImageSource receives the image bytes (byte[]) directly,
// unlike the string file path used by the older generated ModelInput.
AlignersCounterModel.ModelInput sampleData = new AlignersCounterModel.ModelInput()
{
    ImageSource = imageBytes,
};

// Load model and predict output.
var result = AlignersCounterModel.Predict(sampleData);


输入/输出类

/// <summary>
/// Prediction input that identifies the image by its file path in <see cref="ImageSource"/>.
/// </summary>
public class ModelInput
{
    /// <summary>Class label ("Label" column).</summary>
    [ColumnName(@"Label")] public string Label { get; set; }

    /// <summary>Image file path ("ImageSource" column).</summary>
    [ColumnName(@"ImageSource")] public string ImageSource { get; set; }
}
/// <summary>
/// Prediction input that carries the image bytes directly in the "Features" column,
/// so no file-loading transform is needed before the model.
/// </summary>
public class ModelInputBytes
{
    /// <summary>Class label ("Label" column).</summary>
    [ColumnName(@"Label")] public string Label { get; set; }

    /// <summary>Image bytes, bound to the "Features" column the trained model reads.</summary>
    [ColumnName(@"Features")] public byte[] ImageBytes { get; set; }
}
/// <summary>
/// Prediction result: the predicted label plus the raw score vector.
/// </summary>
public class ModelOutput
{
    /// <summary>Predicted label, read from the "PredictedLabel" column.</summary>
    [ColumnName("PredictedLabel")] public string Prediction { get; set; }

    /// <summary>Scores, read from the "Score" column.</summary>
    [ColumnName("Score")] public float[] Score { get; set; }
}

调用模型进行预测的代码（消费代码）：

/// <summary>
/// Predicts the label for the image whose path is given in <see cref="ModelInput.ImageSource"/>.
/// </summary>
/// <param name="input">Input row holding the image path.</param>
/// <returns>The prediction produced by the loaded model.</returns>
/// <exception cref="System.ArgumentNullException"><paramref name="input"/> is null.</exception>
/// <remarks>
/// Every call loads the model from <c>MLNetModelPath</c> and re-fits the preprocessing transformer;
/// cache both if this is called repeatedly.
/// </remarks>
public static ModelOutput Predict(ModelInput input)
{
    if (input == null)
    {
        throw new System.ArgumentNullException(nameof(input));
    }

    MLContext mlContext = new MLContext();

    // Load the trained model; its input schema is not needed here.
    ITransformer mlModel = mlContext.Model.Load(MLNetModelPath, out _);

    // Rebuild the image-loading preprocessing and run it ahead of the loaded model.
    ITransformer dataPreProcessTransform = LoadImageFromFileTransformer(input, mlContext);

    // PredictionEngine is IDisposable; release it once the single prediction has been made.
    using (var predEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(dataPreProcessTransform.Append(mlModel)))
    {
        return predEngine.Predict(input);
    }
}
/// <summary>
/// Builds and fits the preprocessing that turns the image path in the "ImageSource" column into the
/// raw-bytes "Features" column.
/// </summary>
/// <param name="input">Sample row used to give the pipeline a schema to fit against.</param>
/// <param name="mlContext">Context used to create and fit the transforms.</param>
/// <returns>The fitted preprocessing transformer.</returns>
/// <exception cref="System.ArgumentNullException"><paramref name="input"/> or <paramref name="mlContext"/> is null.</exception>
public static ITransformer LoadImageFromFileTransformer(ModelInput input, MLContext mlContext)
{
    if (input == null)
    {
        throw new System.ArgumentNullException(nameof(input));
    }
    if (mlContext == null)
    {
        throw new System.ArgumentNullException(nameof(mlContext));
    }

    // LoadRawImageBytes takes (outputColumnName, imageFolder, inputColumnName). The input column is
    // named explicitly. Passing "ImageSource" as the second positional argument would bind it to
    // imageFolder and leave inputColumnName null. The transform would then look for the nonexistent
    // "ImageSource_featurized" input column.
    var dataPreProcess = mlContext.Transforms.Conversion.MapValueToKey(@"Label", @"Label")
        .Append(mlContext.Transforms.LoadRawImageBytes(
            outputColumnName: @"ImageSource_featurized",
            imageFolder: null,
            inputColumnName: @"ImageSource"))
        .Append(mlContext.Transforms.CopyColumns(@"Features", @"ImageSource_featurized"));

    var dataView = mlContext.Data.LoadFromEnumerable(new[] { input });
    return dataPreProcess.Fit(dataView);
}
/// <summary>
/// Predicts the label directly from image bytes supplied in <see cref="ModelInputBytes.ImageBytes"/>
/// (the "Features" column), without any file-loading preprocessing.
/// </summary>
/// <param name="input">Input row holding the image bytes.</param>
/// <returns>The prediction produced by the loaded model.</returns>
/// <exception cref="System.ArgumentNullException"><paramref name="input"/> is null.</exception>
/// <remarks>
/// Every call loads the model from <c>MLNetModelPath</c>; cache it if this is called repeatedly.
/// </remarks>
public static ModelOutput PredictFromBytes(ModelInputBytes input)
{
    if (input == null)
    {
        throw new System.ArgumentNullException(nameof(input));
    }

    MLContext mlContext = new MLContext();

    // Load the trained model; its input schema is not needed here.
    ITransformer mlModel = mlContext.Model.Load(MLNetModelPath, out _);

    // PredictionEngine is IDisposable; release it once the single prediction has been made.
    using (var predEngine = mlContext.Model.CreatePredictionEngine<ModelInputBytes, ModelOutput>(mlModel))
    {
        return predEngine.Predict(input);
    }
}

最新更新