In the previous article, we finished the first part of our example project. We now have a SentimentAnalyst class that we can use to train a model and to make predictions by passing real data to it.
Before starting the Trainer project, I want to talk a bit about the data we are going to use for training and testing. The dataset was downloaded from Kaggle.
You can download the dataset from this link.
You can download the source code from here.
The details of the dataset are shown below. It is in .csv format, which ML.NET can load directly.
There is one point we need to pay attention to. As you can see below, the sentiment field contains string values, but we need numeric values instead. Because we are going to predict either "Positive" or "Negative", we need two classes: "0" for negative and "1" for positive. That is why I replaced every "negative" label with "0" and every "positive" label with "1". The dataset with the new values is ready to use.
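The article does not show how the labels were replaced. Below is a minimal sketch of one way to do it in C#, assuming each review sits on a single line and the label is the last comma-separated field (review text may contain commas, so only the last comma is used). LabelConverter, ConvertLine and ConvertFile are hypothetical names, not part of the project.

```csharp
using System;
using System.IO;
using System.Linq;

// Hypothetical helper, not part of the original project: rewrites the
// sentiment label in the last CSV field from "positive"/"negative" to "1"/"0".
// Assumes one record per line and that the label is the final field.
public static class LabelConverter
{
    public static string ConvertLine(string line)
    {
        int lastComma = line.LastIndexOf(',');
        if (lastComma < 0)
            return line; // no fields to convert

        string label = line.Substring(lastComma + 1).Trim().Trim('"');
        string value =
            label.Equals("positive", StringComparison.OrdinalIgnoreCase) ? "1" :
            label.Equals("negative", StringComparison.OrdinalIgnoreCase) ? "0" :
            null;

        // Lines with any other label (such as a header row) are left unchanged.
        return value == null ? line : line.Substring(0, lastComma + 1) + value;
    }

    // inputPath and outputPath must differ: ReadLines streams the input lazily.
    public static void ConvertFile(string inputPath, string outputPath) =>
        File.WriteAllLines(outputPath, File.ReadLines(inputPath).Select(ConvertLine));
}
```

For example, `LabelConverter.ConvertFile("IMDBDataset-raw.csv", "IMDBDataset.csv")` would write a converted copy, where both file names are placeholders.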
Now that we have talked about the dataset we will use to train our ML.NET model, we can move on to the Trainer project.
The Trainer project is the second part of our example project, and it is a critical one: every prediction our example application makes depends on the quality of the model trained at this stage.
Let's start by adding a new project to our solution. We are going to add a Console App, so select "Console App" and then click "Next".
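After the project has been created and added to the solution, create a "Data" folder in the "Trainer" project and copy into it the dataset we will use for training and testing with ML.NET. The project should look like the image below. Because Main reads the dataset from the application's output folder (AppDomain.CurrentDomain.BaseDirectory), also set the file's "Copy to Output Directory" property to "Copy if newer".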
Now we can start coding. As we already discussed, this project handles the training process. The SentimentAnalyst class we built earlier does the actual training work; the Trainer project is the layer that passes the necessary parameters to SentimentAnalyst and then displays the training results it returns.
We start by defining the variable below, which determines which build-output folder (Debug or Release) the trained model is saved to:
private static readonly bool _debugMode = true;
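Hard-coding _debugMode means remembering to change it when switching configurations. A minimal sketch of an alternative, assuming the standard DEBUG compilation symbol (defined by default in Visual Studio's Debug configuration); BuildInfo is a hypothetical name. Note that this ties the destination folder to the Trainer's own configuration, which is only correct if Movie Reviews is built with the same one.

```csharp
// Hypothetical alternative to the hard-coded _debugMode flag.
internal static class BuildInfo
{
    // True when compiled in a configuration that defines the DEBUG symbol
    // (Visual Studio's default Debug configuration does; Release does not).
#if DEBUG
    public const bool DebugMode = true;
#else
    public const bool DebugMode = false;
#endif

    // Name of the build-output folder that matches the current configuration.
    public static string ConfigurationFolder => DebugMode ? "Debug" : "Release";
}
```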
Next, let's create our helper functions. "GetParentDirectory" tells us where the Trainer project is running: it walks four levels up from the current working directory to reach the solution folder. We need this because we want the trainer to save the model automatically after training, so that our example application, "Movie Reviews", can find the trained model without any manual copying.
// Gets the solution path: four levels above the working directory
// (e.g. from Trainer\bin\Debug\<target framework> up to the solution folder)
private static string GetParentDirectory()
{
    var directoryInfo = Directory.GetParent(Directory.GetCurrentDirectory())?.Parent;
    if (directoryInfo?.Parent?.Parent != null)
        return directoryInfo.Parent.Parent.FullName;
    return string.Empty;
}
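GetParentDirectory relies on a fixed folder depth, so it returns the wrong folder if the output layout changes (for example, a different target-framework subfolder or none at all). A minimal sketch of a more tolerant alternative, assuming the solution folder is the nearest ancestor that contains a .sln file; SolutionLocator and FindSolutionDirectory are hypothetical names.

```csharp
using System.IO;
using System.Linq;

// Hypothetical alternative to GetParentDirectory: instead of assuming a fixed
// folder depth, walk up from startDirectory until a folder containing a .sln
// file is found.
internal static class SolutionLocator
{
    public static string FindSolutionDirectory(string startDirectory)
    {
        var directory = new DirectoryInfo(startDirectory);
        while (directory != null)
        {
            if (directory.EnumerateFiles("*.sln").Any())
                return directory.FullName;
            directory = directory.Parent;
        }

        return string.Empty; // same fallback as GetParentDirectory
    }
}
```

It would be called as `SolutionLocator.FindSolutionDirectory(Directory.GetCurrentDirectory())`.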
The "Train" function asks ML.NET, through SentimentAnalyst, to train the machine learning model on our dataset. When training finishes, it displays the evaluation metrics shown below so we can see how accurate the model is.
private static void Train(SentimentAnalyst sentimentAnalyst)
{
    // Starts the trainer
    var trainingResult = sentimentAnalyst.Train();
    // Displays the results of the training
    Console.WriteLine("===============================================");
    Console.WriteLine("Accuracy:{0}", trainingResult.Accuracy);
    Console.WriteLine("AreaUnderRocCurve:{0}", trainingResult.AreaUnderRocCurve);
    Console.WriteLine("F1Score:{0}", trainingResult.F1Score);
    Console.WriteLine("===============================================");
}
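To make the printed numbers easier to read: accuracy is the share of reviews classified correctly, F1 is the harmonic mean of precision and recall, and AreaUnderRocCurve measures how well the model ranks positive reviews above negative ones (1.0 is perfect, 0.5 is no better than random guessing). The sketch below, with hypothetical names, shows how accuracy and F1 follow from the counts of a binary confusion matrix. It only illustrates the formulas and is not ML.NET's implementation. For example, 40 true positives, 10 false positives, 45 true negatives and 5 false negatives give an accuracy of 0.85 and an F1 score of 80/95, about 0.842.

```csharp
// Illustrative only: how accuracy and F1 relate to the counts in a binary
// confusion matrix (positive class = 1). Names are hypothetical.
public static class BinaryMetrics
{
    // Correct predictions divided by all predictions.
    public static double Accuracy(int tp, int fp, int tn, int fn) =>
        (double)(tp + tn) / (tp + fp + tn + fn);

    public static double F1Score(int tp, int fp, int fn)
    {
        if (tp == 0)
            return 0.0; // no true positives: precision and recall are both 0

        double precision = (double)tp / (tp + fp); // share of predicted positives that are correct
        double recall = (double)tp / (tp + fn);    // share of actual positives that were found
        return 2 * precision * recall / (precision + recall);
    }
}
```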
"TrainMultiple" is an optional function that trains several learning algorithms at once (LbfgsLogisticRegression, SgdCalibrated, SdcaLogisticRegression, AveragedPerceptron and LinearSvm) and displays their results, so we can see which one fits our dataset best. The results are sorted by ascending accuracy, so the best-performing trainer is printed last.
private static void TrainMultiple(SentimentAnalyst sentimentAnalyst)
{
    Console.WriteLine("Multiple Training");
    // Starts the trainers
    var trainingResults = sentimentAnalyst.TrainMultiple();
    // Displays the results of the training, lowest accuracy first
    Console.WriteLine(
        "*************************************************************************************************************");
    Console.WriteLine("* Training Results ");
    Console.WriteLine(
        "*------------------------------------------------------------------------------------------------------------");
    foreach (var trainingResult in trainingResults.OrderBy(x => x.Accuracy))
    {
        Console.WriteLine($"* Trainer: {trainingResult.Trainer}");
        Console.WriteLine(
            $"* Accuracy: {trainingResult.Accuracy:0.###} - Area Under Roc Curve: ({trainingResult.AreaUnderRocCurve:#.###}) - F1 Score: ({trainingResult.F1Score:#.###})");
    }
    Console.WriteLine(
        "*************************************************************************************************************");
}
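Reading the list by eye works for five trainers, but the choice can also be made in code. A minimal sketch, assuming the result objects expose the same Trainer, Accuracy, AreaUnderRocCurve and F1Score properties the loop above prints; TrainerScore stands in for the real result type from Part 1, and the tie-breaking rule (F1 score when accuracies are equal) is an assumption, not the article's method.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for the result type returned by TrainMultiple();
// only the properties the Trainer project prints are modelled here.
public class TrainerScore
{
    public string Trainer { get; set; }
    public double Accuracy { get; set; }
    public double AreaUnderRocCurve { get; set; }
    public double F1Score { get; set; }
}

public static class TrainerComparison
{
    // Highest accuracy wins; the F1 score breaks ties.
    public static TrainerScore PickBest(IEnumerable<TrainerScore> results) =>
        results.OrderByDescending(r => r.Accuracy)
               .ThenByDescending(r => r.F1Score)
               .First();
}
```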
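Below is the second optional function, "CrossValidation", which is very useful. Cross-validation splits the data into several folds, trains and evaluates the model once per fold, and reports the average accuracy together with its standard deviation and a 95% confidence interval. A large standard deviation means the results depend heavily on which part of the data the model was trained and tested on, which can indicate that the model is overfitting and will not generalize well, so it is a good idea to run it before moving to production.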
private static void CrossValidation(SentimentAnalyst sentimentAnalyst)
{
    Console.WriteLine("Cross Validating");
    // Starts validating
    var validationResult = sentimentAnalyst.CrossValidate();
    Console.WriteLine(
        "*************************************************************************************************************");
    Console.WriteLine("* Metrics for Cross Validation ");
    Console.WriteLine(
        "*------------------------------------------------------------------------------------------------------------");
    Console.WriteLine($"Trainer: {validationResult.Trainer}");
    Console.WriteLine(
        $"* Average Accuracy: {validationResult.AccuracyAverage:0.###} - Standard deviation: ({validationResult.AccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({validationResult.AccuraciesConfidenceInterval95:#.###})");
    Console.WriteLine(
        "*************************************************************************************************************");
}
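CrossValidation prints an average accuracy, a standard deviation and a 95% confidence interval. The sketch below shows the textbook way to compute these from per-fold accuracies, assuming the sample standard deviation and the normal approximation (z = 1.96). SentimentAnalyst.CrossValidate from Part 1 may compute them slightly differently, and the class and method names here are hypothetical. For five folds with accuracies 0.80, 0.82, 0.84, 0.86 and 0.88, the mean is 0.84, the standard deviation is about 0.0316 and the confidence-interval half-width is about 0.0277.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Illustrative helpers (hypothetical names) for summarizing per-fold accuracies.
public static class FoldStatistics
{
    public static double Mean(IReadOnlyList<double> values) => values.Average();

    // Sample standard deviation: divides by n - 1, so at least two folds are needed.
    public static double StdDev(IReadOnlyList<double> values)
    {
        if (values.Count < 2)
            throw new ArgumentException("At least two values are required.", nameof(values));

        double mean = values.Average();
        double sumOfSquares = values.Sum(v => (v - mean) * (v - mean));
        return Math.Sqrt(sumOfSquares / (values.Count - 1));
    }

    // Half-width of a 95% confidence interval for the mean accuracy,
    // using the normal approximation: mean ± 1.96 * s / sqrt(n).
    public static double ConfidenceInterval95(IReadOnlyList<double> values) =>
        1.96 * StdDev(values) / Math.Sqrt(values.Count);
}
```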
Here is the main code. It sets the dataset and model file paths and then creates a new SentimentAnalyst, passing it the training data file and the file the trained model should be saved to.
After this initialization, it calls the "Train" function, which trains the model and displays the results so we can see how the training went.
private static void Main(string[] args)
{
    var trainingDataFile = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "Data", "IMDBDataset.csv");
    // One segment per folder, so Path.Combine inserts the correct separator
    var modelDataFile = Path.Combine(GetParentDirectory(),
        "Movie Reviews", "Movie Reviews", "bin",
        _debugMode ? "Debug" : "Release",
        "Data", "model.zip");
    var sentimentAnalyst = new SentimentAnalyst(trainingDataFile, modelDataFile);
    Console.WriteLine("Training");
    Train(sentimentAnalyst);
    // Uncomment to see how other trainers perform
    //TrainMultiple(sentimentAnalyst);
    // Uncomment to cross-validate the model
    //CrossValidation(sentimentAnalyst);
    Console.WriteLine("Completed");
    Console.ReadLine();
}
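If the Movie Reviews project has never been built, its bin\Debug\Data (or bin\Release\Data) folder may not exist yet, and saving model.zip into it can fail. A minimal sketch of a hypothetical helper that builds the same path as Main and creates the Data folder first; it assumes only that SentimentAnalyst receives the full model file path.

```csharp
using System.IO;

// Hypothetical helper: builds
// <solution>/Movie Reviews/Movie Reviews/bin/<Debug|Release>/Data/model.zip
// and makes sure the Data folder exists before the model is saved.
internal static class ModelPath
{
    public static string Build(string solutionDirectory, bool debugMode)
    {
        string dataFolder = Path.Combine(solutionDirectory, "Movie Reviews", "Movie Reviews",
            "bin", debugMode ? "Debug" : "Release", "Data");
        Directory.CreateDirectory(dataFolder); // does nothing if the folder already exists
        return Path.Combine(dataFolder, "model.zip");
    }
}
```

In Main, this would replace the inline path with `var modelDataFile = ModelPath.Build(GetParentDirectory(), _debugMode);`.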
The final code of the "Trainer" project can be found below.
using System;
using System.IO;
using System.Linq;
// Also add the using directive for the namespace that contains SentimentAnalyst (Part 1).

internal class Program
{
    private static readonly bool _debugMode = true;

    private static void Main(string[] args)
    {
        var trainingDataFile = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "Data", "IMDBDataset.csv");
        // One segment per folder, so Path.Combine inserts the correct separator
        var modelDataFile = Path.Combine(GetParentDirectory(),
            "Movie Reviews", "Movie Reviews", "bin",
            _debugMode ? "Debug" : "Release",
            "Data", "model.zip");
        var sentimentAnalyst = new SentimentAnalyst(trainingDataFile, modelDataFile);
        Console.WriteLine("Training");
        Train(sentimentAnalyst);
        // Uncomment to see how other trainers perform
        //TrainMultiple(sentimentAnalyst);
        // Uncomment to cross-validate the model
        //CrossValidation(sentimentAnalyst);
        Console.WriteLine("Completed");
        Console.ReadLine();
    }

    private static void Train(SentimentAnalyst sentimentAnalyst)
    {
        // Starts the trainer
        var trainingResult = sentimentAnalyst.Train();
        // Displays the results of the training
        Console.WriteLine("===============================================");
        Console.WriteLine("Accuracy:{0}", trainingResult.Accuracy);
        Console.WriteLine("AreaUnderRocCurve:{0}", trainingResult.AreaUnderRocCurve);
        Console.WriteLine("F1Score:{0}", trainingResult.F1Score);
        Console.WriteLine("===============================================");
    }

    private static void TrainMultiple(SentimentAnalyst sentimentAnalyst)
    {
        Console.WriteLine("Multiple Training");
        // Starts the trainers
        var trainingResults = sentimentAnalyst.TrainMultiple();
        // Displays the results of the training, lowest accuracy first
        Console.WriteLine(
            "*************************************************************************************************************");
        Console.WriteLine("* Training Results ");
        Console.WriteLine(
            "*------------------------------------------------------------------------------------------------------------");
        foreach (var trainingResult in trainingResults.OrderBy(x => x.Accuracy))
        {
            Console.WriteLine($"* Trainer: {trainingResult.Trainer}");
            Console.WriteLine(
                $"* Accuracy: {trainingResult.Accuracy:0.###} - Area Under Roc Curve: ({trainingResult.AreaUnderRocCurve:#.###}) - F1 Score: ({trainingResult.F1Score:#.###})");
        }
        Console.WriteLine(
            "*************************************************************************************************************");
    }

    private static void CrossValidation(SentimentAnalyst sentimentAnalyst)
    {
        Console.WriteLine("Cross Validating");
        // Starts validating
        var validationResult = sentimentAnalyst.CrossValidate();
        Console.WriteLine(
            "*************************************************************************************************************");
        Console.WriteLine("* Metrics for Cross Validation ");
        Console.WriteLine(
            "*------------------------------------------------------------------------------------------------------------");
        Console.WriteLine($"Trainer: {validationResult.Trainer}");
        Console.WriteLine(
            $"* Average Accuracy: {validationResult.AccuracyAverage:0.###} - Standard deviation: ({validationResult.AccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({validationResult.AccuraciesConfidenceInterval95:#.###})");
        Console.WriteLine(
            "*************************************************************************************************************");
    }

    // Gets the solution path: four levels above the working directory
    // (e.g. from Trainer\bin\Debug\<target framework> up to the solution folder)
    private static string GetParentDirectory()
    {
        var directoryInfo = Directory.GetParent(Directory.GetCurrentDirectory())?.Parent;
        if (directoryInfo?.Parent?.Parent != null)
            return directoryInfo.Parent.Parent.FullName;
        return string.Empty;
    }
}
Please continue with Sentiment Analysis Part 3 (Interface), which follows the "Trainer" project.