162 lines
3.2 KiB
Go
162 lines
3.2 KiB
Go
package handler
|
|
|
|
import (
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/gofiber/fiber/v2/utils"
|
|
"gorm.io/gorm"
|
|
"rul.sh/furina-id/database"
|
|
"rul.sh/furina-id/models"
|
|
"rul.sh/furina-id/services"
|
|
)
|
|
|
|
var initialChat = models.Chat{
|
|
Role: "model",
|
|
Content: "Ah, finally someone worthy of my time! What shall we do today?",
|
|
}
|
|
|
|
func getAll(c *fiber.Ctx) error {
|
|
sessionId := getSessionId(c)
|
|
if sessionId == 0 {
|
|
return c.JSON([]models.Chat{initialChat})
|
|
}
|
|
|
|
fmt.Println(sessionId)
|
|
|
|
db := database.Get()
|
|
var chats []models.Chat
|
|
|
|
if result := db.Where("session_id = ?", sessionId).Order("id DESC").Limit(10).Find(&chats); result.Error != nil {
|
|
return c.Status(500).JSON(fiber.Map{"message": result.Error.Error()})
|
|
}
|
|
|
|
result := []fiber.Map{}
|
|
for _, chat := range chats {
|
|
result = append(result, fiber.Map{"id": chat.ID, "role": chat.Role, "content": chat.Content})
|
|
}
|
|
|
|
result = append(result, fiber.Map{
|
|
"id": initialChat.ID,
|
|
"role": initialChat.Role,
|
|
"content": initialChat.Content,
|
|
})
|
|
|
|
return c.JSON(result)
|
|
}
|
|
|
|
type CreateChatPayload struct {
|
|
Message string `json:"message"`
|
|
}
|
|
|
|
func create(c *fiber.Ctx) error {
|
|
sessionId := getSessionId(c)
|
|
db := database.Get()
|
|
|
|
if sessionId == 0 {
|
|
var err error
|
|
sessionId, err = createSession(c)
|
|
if err != nil {
|
|
return c.Status(500).JSON(fiber.Map{"message": err.Error()})
|
|
}
|
|
}
|
|
|
|
if count := getTotalChatsLastHour(); count > 60 {
|
|
return c.Status(500).JSON(fiber.Map{"message": "Too many chats. Please try again later."})
|
|
}
|
|
|
|
data := CreateChatPayload{}
|
|
if err := c.BodyParser(&data); err != nil {
|
|
return c.Status(500).JSON(fiber.Map{"message": err.Error()})
|
|
}
|
|
|
|
msg := models.Chat{
|
|
SessionID: sessionId,
|
|
Role: "user",
|
|
Content: data.Message,
|
|
}
|
|
|
|
history := []models.Chat{}
|
|
db.Where("session_id = ?", sessionId).Order("id DESC").Limit(10).Find(&history)
|
|
history = append([]models.Chat{msg}, history...)
|
|
|
|
err := db.Transaction(func(tx *gorm.DB) error {
|
|
resp, err := services.GenerateAiChat(&history, msg.Content)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if r := tx.Create(&msg); r.Error != nil {
|
|
return r.Error
|
|
}
|
|
|
|
resp.SessionID = sessionId
|
|
if r := tx.Create(&resp); r.Error != nil {
|
|
return r.Error
|
|
}
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return c.Status(500).JSON(fiber.Map{"message": err.Error()})
|
|
}
|
|
|
|
return c.JSON(msg)
|
|
}
|
|
|
|
func HandleChat(router fiber.Router) {
|
|
router.Get("/chats", getAll)
|
|
router.Post("/chats", create)
|
|
}
|
|
|
|
func getSessionId(c *fiber.Ctx) uint {
|
|
uuid := c.Cookies("session_id")
|
|
|
|
if uuid == "" {
|
|
return 0
|
|
}
|
|
|
|
db := database.Get()
|
|
var session models.Session
|
|
db.Where("uuid = ?", uuid).First(&session)
|
|
|
|
return session.ID
|
|
}
|
|
|
|
func createSession(c *fiber.Ctx) (uint, error) {
|
|
uuid := utils.UUIDv4()
|
|
|
|
db := database.Get()
|
|
session := models.Session{
|
|
UUID: uuid,
|
|
}
|
|
|
|
if result := db.Create(&session); result.Error != nil {
|
|
return 0, result.Error
|
|
}
|
|
|
|
c.Cookie(&fiber.Cookie{
|
|
Name: "session_id",
|
|
Value: uuid,
|
|
Expires: time.Now().Add(24 * time.Hour),
|
|
HTTPOnly: true,
|
|
})
|
|
|
|
return session.ID, nil
|
|
}
|
|
|
|
func getTotalChatsLastHour() int64 {
|
|
var count int64
|
|
oneHourAgo := time.Now().Add(-1 * time.Hour)
|
|
|
|
r := database.Get().Model(&models.Chat{}).
|
|
Where("created_at >= ?", oneHourAgo).
|
|
Count(&count)
|
|
|
|
if r.Error != nil {
|
|
return 0
|
|
}
|
|
|
|
return count
|
|
}
|