Skip to content

Commit

Permalink
feat: machine learning with .net, question & response basic
Browse files Browse the repository at this point in the history
  • Loading branch information
mtai0524 committed Nov 26, 2024
1 parent d1030cc commit e93ff3d
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 65 deletions.
43 changes: 0 additions & 43 deletions NotaionWebApp/Notaion.Domain/Models/ChatModelTrainer.cs

This file was deleted.

7 changes: 2 additions & 5 deletions NotaionWebApp/Notaion.Infrastructure/DependencyInjection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@
using Notaion.Infrastructure.Persistence;
using Notaion.Infrastructure.Repositories;
using Notaion.Infrastructure.Services;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace Notaion.Infrastructure
{
Expand Down Expand Up @@ -47,6 +42,8 @@ public static IServiceCollection AddInfrastructure(this IServiceCollection servi

services.AddScoped<ICloudinaryService, CloudinaryService>();

services.AddSingleton<ChatModelTrainer>();

return services;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
<PackageReference Include="Microsoft.AspNetCore.Identity" Version="2.2.0" />
<PackageReference Include="Microsoft.AspNetCore.Mvc.Core" Version="2.2.5" />
<PackageReference Include="Microsoft.EntityFrameworkCore.SqlServer" Version="7.0.20" />
<PackageReference Include="Microsoft.ML" Version="4.0.0" />
<PackageReference Include="OpenAI" Version="2.0.0" />
</ItemGroup>

Expand Down
122 changes: 122 additions & 0 deletions NotaionWebApp/Notaion.Infrastructure/Repositories/ChatModelTrainer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
using Microsoft.ML;
using Microsoft.ML.Data;

namespace Notaion.Infrastructure.Repositories
{
// Định nghĩa lớp chứa dữ liệu huấn luyện
public class ChatData
{
[LoadColumn(0)]
public string Question { get; set; }

[LoadColumn(1)]
public string Response { get; set; }
}

// Lớp huấn luyện mô hình
public class ChatModelTrainer
{
private readonly string ModelPath = Path.Combine(Directory.GetCurrentDirectory(), "chatbot_model.zip");

// Huấn luyện mô hình từ file CSV
public async Task TrainModelFromCsvAsync(string filePath)
{
try
{
var mlContext = new MLContext();

// Đọc dữ liệu từ tệp CSV
var data = mlContext.Data.LoadFromTextFile<ChatData>(filePath, separatorChar: ',', hasHeader: true);

var preview = data.Preview(10); // Hiển thị 10 dòng đầu tiên của dữ liệu
foreach (var row in preview.RowView)
{
Console.WriteLine($"Question: {row.Values[0].Value}, Response: {row.Values[1].Value}");
}

// Xây dựng pipeline huấn luyện
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", nameof(ChatData.Question)) // Tính năng từ câu hỏi
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label", nameof(ChatData.Response))) // Câu trả lời là nhãn
.Append(mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy("Label", "Features")) // Huấn luyện phân loại
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel", "PredictedLabel"));

// Huấn luyện mô hình
var model = pipeline.Fit(data);

// Kiểm tra kết quả trước khi lưu mô hình
var transformer = model.Transform(data);
var predictions = mlContext.Data.CreateEnumerable<ChatPrediction>(transformer, reuseRowObject: false).ToList();

foreach (var prediction in predictions)
{
if (string.IsNullOrWhiteSpace(prediction.PredictedLabel))
{
Console.WriteLine("Dự đoán bị thiếu cho câu hỏi: " + prediction.Question);
}
else
{
Console.WriteLine($"Câu hỏi: {prediction.Question}, Dự đoán câu trả lời: {prediction.PredictedLabel}");
}
}

// Lưu mô hình vào file
mlContext.Model.Save(model, data.Schema, ModelPath);
Console.WriteLine($"Mô hình đã được lưu tại: {ModelPath}");
}
catch (Exception ex)
{
Console.WriteLine($"Lỗi khi huấn luyện mô hình: {ex.Message}");
throw;
}
}

// Dự đoán câu trả lời từ câu hỏi của người dùng
public async Task<string> PredictResponseAsync(string userMessage)
{
try
{
var mlContext = new MLContext();

// Kiểm tra xem mô hình có tồn tại không
if (!File.Exists(ModelPath))
{
throw new FileNotFoundException("Mô hình không tồn tại tại: " + ModelPath);
}

// Load mô hình đã huấn luyện
var model = mlContext.Model.Load(ModelPath, out var modelInputSchema);

// Tạo prediction engine
var predictionEngine = mlContext.Model.CreatePredictionEngine<ChatData, ChatPrediction>(model);

// Sử dụng mô hình thực tế để dự đoán phản hồi
var prediction = predictionEngine.Predict(new ChatData { Question = userMessage });

// Kiểm tra kết quả dự đoán từ mô hình
if (string.IsNullOrWhiteSpace(prediction?.PredictedLabel))
{
return "Xin lỗi, tôi không thể trả lời câu hỏi này.";
}

// Trả về kết quả dự đoán
return prediction.PredictedLabel;
}
catch (FileNotFoundException ex)
{
return "Không tìm thấy mô hình: " + ex.Message;
}
catch (Exception ex)
{
return "Đã xảy ra lỗi khi dự đoán: " + ex.Message;
}
}
}


// Dự đoán phản hồi cho một câu hỏi
public class ChatPrediction
{
public string Question { get; set; } // Câu hỏi gốc
public string PredictedLabel { get; set; } // Câu trả lời dự đoán
}
}
62 changes: 50 additions & 12 deletions NotaionWebApp/Notaion/Controllers/ChatController.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,12 @@
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.SignalR;
using Notaion.Infrastructure.Context;
using Notaion.Domain.Entities;
using Notaion.Hubs;
using Microsoft.EntityFrameworkCore;
using Notaion.Domain.Models;
using Notaion.Application.DTOs.Chats;
using Notaion.Application.Common.Helpers;
using AutoMapper;
using Notaion.Application.Services;
using Notaion.Application.Interfaces.Services;
using Microsoft.AspNetCore.Authorization;
using System;
using Notaion.Domain.Interfaces;
using Notaion.Hubs;
using Notaion.Infrastructure.Context;
using Notaion.Infrastructure.Repositories;

namespace Notaion.Controllers
{
Expand All @@ -23,11 +17,55 @@ public class ChatController : ControllerBase
private readonly ApplicationDbContext _context;
private readonly IHubContext<ChatHub> _hubContext;
private readonly IChatService chatService;
public ChatController(ApplicationDbContext context, IHubContext<ChatHub> hubContext, IChatService chatService)
private readonly ChatModelTrainer _chatModelTrainer;
public ChatController(ApplicationDbContext context, IHubContext<ChatHub> hubContext, IChatService chatService, ChatModelTrainer chatModelTrainer)
{
_context = context;
this.chatService = chatService;
_hubContext = hubContext;
_chatModelTrainer = chatModelTrainer;
}

[HttpPost("train")]
public async Task<IActionResult> TrainChatbotModel()
{
try
{
// Sử dụng đường dẫn cố định đến file responses.csv
string filePath = Path.Combine(Directory.GetCurrentDirectory(), "responses.csv");

// Kiểm tra xem file có tồn tại không
if (!System.IO.File.Exists(filePath))
{
return BadRequest($"File không tồn tại tại đường dẫn: {filePath}");
}
var modelTrainer = new ChatModelTrainer();
// Huấn luyện mô hình từ file
await modelTrainer.TrainModelFromCsvAsync(filePath);
return Ok("Mô hình đã được huấn luyện thành công.");
}
catch (Exception ex)
{
return BadRequest($"Đã xảy ra lỗi khi huấn luyện mô hình: {ex.Message}");
}
}



// Dự đoán câu trả lời từ câu hỏi của người dùng
[HttpPost("predict")]
public async Task<IActionResult> PredictResponse([FromBody] string userMessage)
{
try
{
// Gọi phương thức bất đồng bộ
var response = await _chatModelTrainer.PredictResponseAsync(userMessage);
return Ok(new { Response = response });
}
catch (Exception ex)
{
return BadRequest($"Đã xảy ra lỗi trong quá trình dự đoán: {ex.Message}");
}
}

//[HttpGet("test-genaric-repo")]
Expand Down
1 change: 1 addition & 0 deletions NotaionWebApp/Notaion/Notaion.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
</PackageReference>
<PackageReference Include="Microsoft.Extensions.Configuration.Binder" Version="7.0.4" />
<PackageReference Include="Microsoft.Identity.Web" Version="1.16.0" />
<PackageReference Include="Microsoft.ML" Version="4.0.0" />
<PackageReference Include="Microsoft.VisualStudio.Azure.Containers.Tools.Targets" Version="1.19.5" />
<PackageReference Include="Microsoft.VisualStudio.Web.CodeGeneration.Design" Version="7.0.12" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.5.0" />
Expand Down
Binary file added NotaionWebApp/Notaion/chatbot_model.zip
Binary file not shown.
8 changes: 3 additions & 5 deletions NotaionWebApp/Notaion/responses.csv
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
Question,Response
"Xin chào","Chào bạn, tôi có thể giúp gì?"
"hello","hi"
"Thời tiết hôm nay thế nào?","Thời tiết hôm nay nắng đẹp."
"Bạn tên gì?","Tôi là Chatbot."
"Tôi muốn học lập trình","Hãy bắt đầu với những bài học cơ bản."
"Giá vàng hôm nay là bao nhiêu?","Hiện tại tôi chưa có thông tin về giá vàng."

"Bạn tên gì?","Tôi là một chatbot."
"Chào bạn!","Chào bạn!"

0 comments on commit e93ff3d

Please sign in to comment.