create project
This commit is contained in:
398
application/services/ai_service.go
Normal file
398
application/services/ai_service.go
Normal file
@@ -0,0 +1,398 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/pkg/ai"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AIService struct {
|
||||
db *gorm.DB
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
|
||||
return &AIService{
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateAIConfigRequest struct {
|
||||
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
|
||||
Name string `json:"name" binding:"required,min=1,max=100"`
|
||||
Provider string `json:"provider" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required,url"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
Model models.ModelField `json:"model" binding:"required"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
QueryEndpoint string `json:"query_endpoint"`
|
||||
Priority int `json:"priority"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
Settings string `json:"settings"`
|
||||
}
|
||||
|
||||
type UpdateAIConfigRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,min=1,max=100"`
|
||||
Provider string `json:"provider"`
|
||||
BaseURL string `json:"base_url" binding:"omitempty,url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Model *models.ModelField `json:"model"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
QueryEndpoint string `json:"query_endpoint"`
|
||||
Priority *int `json:"priority"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Settings string `json:"settings"`
|
||||
}
|
||||
|
||||
type TestConnectionRequest struct {
|
||||
BaseURL string `json:"base_url" binding:"required,url"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
Model models.ModelField `json:"model" binding:"required"`
|
||||
Provider string `json:"provider"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
|
||||
// 根据 provider 和 service_type 自动设置 endpoint
|
||||
endpoint := req.Endpoint
|
||||
queryEndpoint := req.QueryEndpoint
|
||||
|
||||
if endpoint == "" {
|
||||
switch req.Provider {
|
||||
case "gemini", "google":
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
}
|
||||
case "openai":
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/chat/completions"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/images/generations"
|
||||
} else if req.ServiceType == "video" {
|
||||
endpoint = "/videos"
|
||||
if queryEndpoint == "" {
|
||||
queryEndpoint = "/videos/{taskId}"
|
||||
}
|
||||
}
|
||||
case "chatfire":
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/chat/completions"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/images/generations"
|
||||
} else if req.ServiceType == "video" {
|
||||
endpoint = "/video/generations"
|
||||
if queryEndpoint == "" {
|
||||
queryEndpoint = "/video/task/{taskId}"
|
||||
}
|
||||
}
|
||||
case "doubao", "volcengine", "volces":
|
||||
if req.ServiceType == "video" {
|
||||
endpoint = "/contents/generations/tasks"
|
||||
if queryEndpoint == "" {
|
||||
queryEndpoint = "/generations/tasks/{taskId}"
|
||||
}
|
||||
}
|
||||
default:
|
||||
// 默认使用 OpenAI 格式
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/chat/completions"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/images/generations"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config := &models.AIServiceConfig{
|
||||
ServiceType: req.ServiceType,
|
||||
Name: req.Name,
|
||||
Provider: req.Provider,
|
||||
BaseURL: req.BaseURL,
|
||||
APIKey: req.APIKey,
|
||||
Model: req.Model,
|
||||
Endpoint: endpoint,
|
||||
QueryEndpoint: queryEndpoint,
|
||||
Priority: req.Priority,
|
||||
IsDefault: req.IsDefault,
|
||||
IsActive: true,
|
||||
Settings: req.Settings,
|
||||
}
|
||||
|
||||
if err := s.db.Create(config).Error; err != nil {
|
||||
s.log.Errorw("Failed to create AI config", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (s *AIService) GetConfig(configID uint) (*models.AIServiceConfig, error) {
|
||||
var config models.AIServiceConfig
|
||||
err := s.db.Where("id = ? ", configID).First(&config).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("config not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (s *AIService) ListConfigs(serviceType string) ([]models.AIServiceConfig, error) {
|
||||
var configs []models.AIServiceConfig
|
||||
query := s.db
|
||||
|
||||
if serviceType != "" {
|
||||
query = query.Where("service_type = ?", serviceType)
|
||||
}
|
||||
|
||||
err := query.Order("priority DESC, created_at DESC").Find(&configs).Error
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to list AI configs", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*models.AIServiceConfig, error) {
|
||||
var config models.AIServiceConfig
|
||||
if err := s.db.Where("id = ? ", configID).First(&config).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("config not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tx := s.db.Begin()
|
||||
|
||||
// 不再需要is_default独占逻辑
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
if req.Name != "" {
|
||||
updates["name"] = req.Name
|
||||
}
|
||||
if req.Provider != "" {
|
||||
updates["provider"] = req.Provider
|
||||
}
|
||||
if req.BaseURL != "" {
|
||||
updates["base_url"] = req.BaseURL
|
||||
}
|
||||
if req.APIKey != "" {
|
||||
updates["api_key"] = req.APIKey
|
||||
}
|
||||
if req.Model != nil && len(*req.Model) > 0 {
|
||||
updates["model"] = *req.Model
|
||||
}
|
||||
if req.Priority != nil {
|
||||
updates["priority"] = *req.Priority
|
||||
}
|
||||
|
||||
// 如果提供了 provider,根据 provider 和 service_type 自动设置 endpoint
|
||||
if req.Provider != "" && req.Endpoint == "" {
|
||||
provider := req.Provider
|
||||
serviceType := config.ServiceType
|
||||
|
||||
switch provider {
|
||||
case "gemini", "google":
|
||||
if serviceType == "text" || serviceType == "image" {
|
||||
updates["endpoint"] = "/v1beta/models/{model}:generateContent"
|
||||
}
|
||||
case "openai":
|
||||
if serviceType == "text" {
|
||||
updates["endpoint"] = "/chat/completions"
|
||||
} else if serviceType == "image" {
|
||||
updates["endpoint"] = "/images/generations"
|
||||
} else if serviceType == "video" {
|
||||
updates["endpoint"] = "/videos"
|
||||
updates["query_endpoint"] = "/videos/{taskId}"
|
||||
}
|
||||
case "chatfire":
|
||||
if serviceType == "text" {
|
||||
updates["endpoint"] = "/chat/completions"
|
||||
} else if serviceType == "image" {
|
||||
updates["endpoint"] = "/images/generations"
|
||||
} else if serviceType == "video" {
|
||||
updates["endpoint"] = "/video/generations"
|
||||
updates["query_endpoint"] = "/video/task/{taskId}"
|
||||
}
|
||||
}
|
||||
} else if req.Endpoint != "" {
|
||||
updates["endpoint"] = req.Endpoint
|
||||
}
|
||||
|
||||
// 允许清空query_endpoint,所以不检查是否为空
|
||||
updates["query_endpoint"] = req.QueryEndpoint
|
||||
if req.Settings != "" {
|
||||
updates["settings"] = req.Settings
|
||||
}
|
||||
updates["is_default"] = req.IsDefault
|
||||
updates["is_active"] = req.IsActive
|
||||
|
||||
if err := tx.Model(&config).Updates(updates).Error; err != nil {
|
||||
tx.Rollback()
|
||||
s.log.Errorw("Failed to update AI config", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.log.Infow("AI config updated", "config_id", configID)
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (s *AIService) DeleteConfig(configID uint) error {
|
||||
result := s.db.Where("id = ? ", configID).Delete(&models.AIServiceConfig{})
|
||||
|
||||
if result.Error != nil {
|
||||
s.log.Errorw("Failed to delete AI config", "error", result.Error)
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return errors.New("config not found")
|
||||
}
|
||||
|
||||
s.log.Infow("AI config deleted", "config_id", configID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
|
||||
s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model))
|
||||
|
||||
// 使用第一个模型进行测试
|
||||
model := ""
|
||||
if len(req.Model) > 0 {
|
||||
model = req.Model[0]
|
||||
}
|
||||
s.log.Infow("Using model for test", "model", model, "provider", req.Provider)
|
||||
|
||||
// 根据 provider 参数选择客户端
|
||||
var client ai.AIClient
|
||||
var endpoint string
|
||||
|
||||
switch req.Provider {
|
||||
case "gemini", "google":
|
||||
// Gemini
|
||||
s.log.Infow("Using Gemini client", "baseURL", req.BaseURL)
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||
case "openai", "chatfire":
|
||||
// OpenAI 格式(包括 chatfire 等)
|
||||
s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider)
|
||||
endpoint = req.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = "/chat/completions"
|
||||
}
|
||||
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||
default:
|
||||
// 默认使用 OpenAI 格式
|
||||
s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL)
|
||||
endpoint = req.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = "/chat/completions"
|
||||
}
|
||||
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||
}
|
||||
|
||||
s.log.Infow("Calling TestConnection on client", "endpoint", endpoint)
|
||||
err := client.TestConnection()
|
||||
if err != nil {
|
||||
s.log.Errorw("TestConnection failed", "error", err)
|
||||
} else {
|
||||
s.log.Infow("TestConnection succeeded")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
|
||||
var config models.AIServiceConfig
|
||||
// 按优先级降序获取第一个配置
|
||||
err := s.db.Where("service_type = ?", serviceType).
|
||||
Order("priority DESC, created_at DESC").
|
||||
First(&config).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("no config found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// GetConfigForModel 根据服务类型和模型名称获取优先级最高的配置
|
||||
func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*models.AIServiceConfig, error) {
|
||||
var configs []models.AIServiceConfig
|
||||
err := s.db.Where("service_type = ?", serviceType).
|
||||
Order("priority DESC, created_at DESC").
|
||||
Find(&configs).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 查找包含指定模型的配置
|
||||
for _, config := range configs {
|
||||
for _, model := range config.Model {
|
||||
if model == modelName {
|
||||
return &config, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("no config found for model: " + modelName)
|
||||
}
|
||||
|
||||
func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) {
|
||||
config, err := s.GetDefaultConfig(serviceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用第一个模型
|
||||
model := ""
|
||||
if len(config.Model) > 0 {
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
// 使用数据库配置中的 endpoint,如果为空则根据 provider 设置默认值
|
||||
endpoint := config.Endpoint
|
||||
if endpoint == "" {
|
||||
switch config.Provider {
|
||||
case "gemini", "google":
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
default:
|
||||
endpoint = "/chat/completions"
|
||||
}
|
||||
}
|
||||
|
||||
// 根据 provider 创建对应的客户端
|
||||
switch config.Provider {
|
||||
case "gemini", "google":
|
||||
return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
default:
|
||||
// openai, chatfire 等其他厂商都使用 OpenAI 格式
|
||||
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {
|
||||
client, err := s.GetAIClient("text")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get AI client: %w", err)
|
||||
}
|
||||
|
||||
return client.GenerateText(prompt, systemPrompt, options...)
|
||||
}
|
||||
287
application/services/asset_service.go
Normal file
287
application/services/asset_service.go
Normal file
@@ -0,0 +1,287 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
models "github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type AssetService struct {
|
||||
db *gorm.DB
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewAssetService(db *gorm.DB, log *logger.Logger) *AssetService {
|
||||
return &AssetService{
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateAssetRequest struct {
|
||||
DramaID *string `json:"drama_id"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Description *string `json:"description"`
|
||||
Type models.AssetType `json:"type" binding:"required"`
|
||||
Category *string `json:"category"`
|
||||
URL string `json:"url" binding:"required"`
|
||||
ThumbnailURL *string `json:"thumbnail_url"`
|
||||
LocalPath *string `json:"local_path"`
|
||||
FileSize *int64 `json:"file_size"`
|
||||
MimeType *string `json:"mime_type"`
|
||||
Width *int `json:"width"`
|
||||
Height *int `json:"height"`
|
||||
Duration *int `json:"duration"`
|
||||
Format *string `json:"format"`
|
||||
ImageGenID *uint `json:"image_gen_id"`
|
||||
VideoGenID *uint `json:"video_gen_id"`
|
||||
TagIDs []uint `json:"tag_ids"`
|
||||
}
|
||||
|
||||
type UpdateAssetRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Description *string `json:"description"`
|
||||
Category *string `json:"category"`
|
||||
ThumbnailURL *string `json:"thumbnail_url"`
|
||||
TagIDs []uint `json:"tag_ids"`
|
||||
IsFavorite *bool `json:"is_favorite"`
|
||||
}
|
||||
|
||||
type ListAssetsRequest struct {
|
||||
DramaID *string `json:"drama_id"`
|
||||
EpisodeID *uint `json:"episode_id"`
|
||||
StoryboardID *uint `json:"storyboard_id"`
|
||||
Type *models.AssetType `json:"type"`
|
||||
Category string `json:"category"`
|
||||
TagIDs []uint `json:"tag_ids"`
|
||||
IsFavorite *bool `json:"is_favorite"`
|
||||
Search string `json:"search"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
}
|
||||
|
||||
func (s *AssetService) CreateAsset(req *CreateAssetRequest) (*models.Asset, error) {
|
||||
var dramaID *uint
|
||||
if req.DramaID != nil && *req.DramaID != "" {
|
||||
id, err := strconv.ParseUint(*req.DramaID, 10, 32)
|
||||
if err == nil {
|
||||
uid := uint(id)
|
||||
dramaID = &uid
|
||||
}
|
||||
}
|
||||
|
||||
if dramaID != nil {
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ?", *dramaID).First(&drama).Error; err != nil {
|
||||
return nil, fmt.Errorf("drama not found")
|
||||
}
|
||||
}
|
||||
|
||||
asset := &models.Asset{
|
||||
DramaID: dramaID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Type: req.Type,
|
||||
Category: req.Category,
|
||||
URL: req.URL,
|
||||
ThumbnailURL: req.ThumbnailURL,
|
||||
LocalPath: req.LocalPath,
|
||||
FileSize: req.FileSize,
|
||||
MimeType: req.MimeType,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
Duration: req.Duration,
|
||||
Format: req.Format,
|
||||
ImageGenID: req.ImageGenID,
|
||||
VideoGenID: req.VideoGenID,
|
||||
}
|
||||
|
||||
if err := s.db.Create(asset).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create asset: %w", err)
|
||||
}
|
||||
|
||||
return asset, nil
|
||||
}
|
||||
|
||||
func (s *AssetService) UpdateAsset(assetID uint, req *UpdateAssetRequest) (*models.Asset, error) {
|
||||
var asset models.Asset
|
||||
if err := s.db.Where("id = ?", assetID).First(&asset).Error; err != nil {
|
||||
return nil, fmt.Errorf("asset not found")
|
||||
}
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
if req.Name != nil {
|
||||
updates["name"] = *req.Name
|
||||
}
|
||||
if req.Description != nil {
|
||||
updates["description"] = *req.Description
|
||||
}
|
||||
if req.Category != nil {
|
||||
updates["category"] = *req.Category
|
||||
}
|
||||
if req.ThumbnailURL != nil {
|
||||
updates["thumbnail_url"] = *req.ThumbnailURL
|
||||
}
|
||||
if req.IsFavorite != nil {
|
||||
updates["is_favorite"] = *req.IsFavorite
|
||||
}
|
||||
|
||||
if len(updates) > 0 {
|
||||
if err := s.db.Model(&asset).Updates(updates).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to update asset: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.db.First(&asset, assetID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &asset, nil
|
||||
}
|
||||
|
||||
func (s *AssetService) GetAsset(assetID uint) (*models.Asset, error) {
|
||||
var asset models.Asset
|
||||
if err := s.db.Where("id = ? ", assetID).First(&asset).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.db.Model(&asset).UpdateColumn("view_count", gorm.Expr("view_count + ?", 1))
|
||||
|
||||
return &asset, nil
|
||||
}
|
||||
|
||||
func (s *AssetService) ListAssets(req *ListAssetsRequest) ([]models.Asset, int64, error) {
|
||||
query := s.db.Model(&models.Asset{})
|
||||
|
||||
if req.DramaID != nil {
|
||||
var dramaID uint64
|
||||
dramaID, _ = strconv.ParseUint(*req.DramaID, 10, 32)
|
||||
query = query.Where("drama_id = ?", uint(dramaID))
|
||||
}
|
||||
|
||||
if req.EpisodeID != nil {
|
||||
query = query.Where("episode_id = ?", *req.EpisodeID)
|
||||
}
|
||||
|
||||
if req.StoryboardID != nil {
|
||||
query = query.Where("storyboard_id = ?", *req.StoryboardID)
|
||||
}
|
||||
|
||||
if req.Type != nil {
|
||||
query = query.Where("type = ?", *req.Type)
|
||||
}
|
||||
|
||||
if req.Category != "" {
|
||||
query = query.Where("category = ?", req.Category)
|
||||
}
|
||||
|
||||
if req.IsFavorite != nil {
|
||||
query = query.Where("is_favorite = ?", *req.IsFavorite)
|
||||
}
|
||||
|
||||
if req.Search != "" {
|
||||
searchTerm := "%" + strings.ToLower(req.Search) + "%"
|
||||
query = query.Where("LOWER(name) LIKE ? OR LOWER(description) LIKE ?", searchTerm, searchTerm)
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var assets []models.Asset
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
if err := query.Order("created_at DESC").
|
||||
Offset(offset).Limit(req.PageSize).Find(&assets).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return assets, total, nil
|
||||
}
|
||||
|
||||
func (s *AssetService) DeleteAsset(assetID uint) error {
|
||||
result := s.db.Where("id = ?", assetID).Delete(&models.Asset{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("asset not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AssetService) ImportFromImageGen(imageGenID uint) (*models.Asset, error) {
|
||||
var imageGen models.ImageGeneration
|
||||
if err := s.db.Where("id = ? ", imageGenID).First(&imageGen).Error; err != nil {
|
||||
return nil, fmt.Errorf("image generation not found")
|
||||
}
|
||||
|
||||
if imageGen.Status != models.ImageStatusCompleted || imageGen.ImageURL == nil {
|
||||
return nil, fmt.Errorf("image is not ready")
|
||||
}
|
||||
|
||||
dramaID := imageGen.DramaID
|
||||
asset := &models.Asset{
|
||||
Name: fmt.Sprintf("Image_%d", imageGen.ID),
|
||||
Type: models.AssetTypeImage,
|
||||
URL: *imageGen.ImageURL,
|
||||
DramaID: &dramaID,
|
||||
ImageGenID: &imageGenID,
|
||||
Width: imageGen.Width,
|
||||
Height: imageGen.Height,
|
||||
}
|
||||
|
||||
if err := s.db.Create(asset).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create asset: %w", err)
|
||||
}
|
||||
|
||||
return asset, nil
|
||||
}
|
||||
|
||||
func (s *AssetService) ImportFromVideoGen(videoGenID uint) (*models.Asset, error) {
|
||||
var videoGen models.VideoGeneration
|
||||
if err := s.db.Preload("Storyboard.Episode").Where("id = ? ", videoGenID).First(&videoGen).Error; err != nil {
|
||||
return nil, fmt.Errorf("video generation not found")
|
||||
}
|
||||
|
||||
if videoGen.Status != models.VideoStatusCompleted || videoGen.VideoURL == nil {
|
||||
return nil, fmt.Errorf("video is not ready")
|
||||
}
|
||||
|
||||
dramaID := videoGen.DramaID
|
||||
|
||||
var episodeID *uint
|
||||
var storyboardNum *int
|
||||
if videoGen.Storyboard != nil {
|
||||
episodeID = &videoGen.Storyboard.Episode.ID
|
||||
storyboardNum = &videoGen.Storyboard.StoryboardNumber
|
||||
}
|
||||
|
||||
asset := &models.Asset{
|
||||
Name: fmt.Sprintf("Video_%d", videoGen.ID),
|
||||
Type: models.AssetTypeVideo,
|
||||
URL: *videoGen.VideoURL,
|
||||
DramaID: &dramaID,
|
||||
EpisodeID: episodeID,
|
||||
StoryboardID: videoGen.StoryboardID,
|
||||
StoryboardNum: storyboardNum,
|
||||
VideoGenID: &videoGenID,
|
||||
Duration: videoGen.Duration,
|
||||
Width: videoGen.Width,
|
||||
Height: videoGen.Height,
|
||||
}
|
||||
|
||||
if videoGen.FirstFrameURL != nil {
|
||||
asset.ThumbnailURL = videoGen.FirstFrameURL
|
||||
}
|
||||
|
||||
if err := s.db.Create(asset).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create asset: %w", err)
|
||||
}
|
||||
|
||||
return asset, nil
|
||||
}
|
||||
473
application/services/character_library_service.go
Normal file
473
application/services/character_library_service.go
Normal file
@@ -0,0 +1,473 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
models "github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type CharacterLibraryService struct {
|
||||
db *gorm.DB
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewCharacterLibraryService(db *gorm.DB, log *logger.Logger) *CharacterLibraryService {
|
||||
return &CharacterLibraryService{
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateLibraryItemRequest struct {
|
||||
Name string `json:"name" binding:"required,min=1,max=100"`
|
||||
Category *string `json:"category"`
|
||||
ImageURL string `json:"image_url" binding:"required"`
|
||||
Description *string `json:"description"`
|
||||
Tags *string `json:"tags"`
|
||||
SourceType string `json:"source_type"`
|
||||
}
|
||||
|
||||
type CharacterLibraryQuery struct {
|
||||
Page int `form:"page,default=1"`
|
||||
PageSize int `form:"page_size,default=20"`
|
||||
Category string `form:"category"`
|
||||
SourceType string `form:"source_type"`
|
||||
Keyword string `form:"keyword"`
|
||||
}
|
||||
|
||||
// ListLibraryItems 获取用户角色库列表
|
||||
func (s *CharacterLibraryService) ListLibraryItems(query *CharacterLibraryQuery) ([]models.CharacterLibrary, int64, error) {
|
||||
var items []models.CharacterLibrary
|
||||
var total int64
|
||||
|
||||
db := s.db.Model(&models.CharacterLibrary{})
|
||||
|
||||
// 筛选条件
|
||||
if query.Category != "" {
|
||||
db = db.Where("category = ?", query.Category)
|
||||
}
|
||||
|
||||
if query.SourceType != "" {
|
||||
db = db.Where("source_type = ?", query.SourceType)
|
||||
}
|
||||
|
||||
if query.Keyword != "" {
|
||||
db = db.Where("name LIKE ? OR description LIKE ?", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
s.log.Errorw("Failed to count character library", "error", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
err := db.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(query.PageSize).
|
||||
Find(&items).Error
|
||||
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to list character library", "error", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return items, total, nil
|
||||
}
|
||||
|
||||
// CreateLibraryItem 添加到角色库
|
||||
func (s *CharacterLibraryService) CreateLibraryItem(req *CreateLibraryItemRequest) (*models.CharacterLibrary, error) {
|
||||
sourceType := req.SourceType
|
||||
if sourceType == "" {
|
||||
sourceType = "generated"
|
||||
}
|
||||
|
||||
item := &models.CharacterLibrary{
|
||||
Name: req.Name,
|
||||
Category: req.Category,
|
||||
ImageURL: req.ImageURL,
|
||||
Description: req.Description,
|
||||
Tags: req.Tags,
|
||||
SourceType: sourceType,
|
||||
}
|
||||
|
||||
if err := s.db.Create(item).Error; err != nil {
|
||||
s.log.Errorw("Failed to create library item", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.log.Infow("Library item created", "item_id", item.ID)
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// GetLibraryItem 获取角色库项
|
||||
func (s *CharacterLibraryService) GetLibraryItem(itemID string) (*models.CharacterLibrary, error) {
|
||||
var item models.CharacterLibrary
|
||||
err := s.db.Where("id = ? ", itemID).First(&item).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("library item not found")
|
||||
}
|
||||
s.log.Errorw("Failed to get library item", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &item, nil
|
||||
}
|
||||
|
||||
// DeleteLibraryItem 删除角色库项
|
||||
func (s *CharacterLibraryService) DeleteLibraryItem(itemID string) error {
|
||||
result := s.db.Where("id = ? ", itemID).Delete(&models.CharacterLibrary{})
|
||||
|
||||
if result.Error != nil {
|
||||
s.log.Errorw("Failed to delete library item", "error", result.Error)
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return errors.New("library item not found")
|
||||
}
|
||||
|
||||
s.log.Infow("Library item deleted", "item_id", itemID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyLibraryItemToCharacter 将角色库形象应用到角色
|
||||
func (s *CharacterLibraryService) ApplyLibraryItemToCharacter(characterID string, libraryItemID string) error {
|
||||
// 验证角色库项存在且属于该用户
|
||||
var libraryItem models.CharacterLibrary
|
||||
if err := s.db.Where("id = ? ", libraryItemID).First(&libraryItem).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("library item not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 查找角色
|
||||
var character models.Character
|
||||
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("character not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 查询Drama验证权限
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("unauthorized")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新角色的image_url
|
||||
if err := s.db.Model(&character).Update("image_url", libraryItem.ImageURL).Error; err != nil {
|
||||
s.log.Errorw("Failed to update character image", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.Infow("Library item applied to character", "character_id", characterID, "library_item_id", libraryItemID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UploadCharacterImage 上传角色图片
|
||||
func (s *CharacterLibraryService) UploadCharacterImage(characterID string, imageURL string) error {
|
||||
// 查找角色
|
||||
var character models.Character
|
||||
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("character not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 查询Drama验证权限
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("unauthorized")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新图片URL
|
||||
if err := s.db.Model(&character).Update("image_url", imageURL).Error; err != nil {
|
||||
s.log.Errorw("Failed to update character image", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.Infow("Character image uploaded", "character_id", characterID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddCharacterToLibrary 将角色添加到角色库
|
||||
func (s *CharacterLibraryService) AddCharacterToLibrary(characterID string, category *string) (*models.CharacterLibrary, error) {
|
||||
// 查找角色
|
||||
var character models.Character
|
||||
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("character not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 查询Drama验证权限
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("unauthorized")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查是否有图片
|
||||
if character.ImageURL == nil || *character.ImageURL == "" {
|
||||
return nil, fmt.Errorf("角色还没有形象图片")
|
||||
}
|
||||
|
||||
// 创建角色库项
|
||||
charLibrary := &models.CharacterLibrary{
|
||||
Name: character.Name,
|
||||
ImageURL: *character.ImageURL,
|
||||
Description: character.Description,
|
||||
SourceType: "character",
|
||||
}
|
||||
|
||||
if err := s.db.Create(charLibrary).Error; err != nil {
|
||||
s.log.Errorw("Failed to add character to library", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.log.Infow("Character added to library", "character_id", characterID, "library_item_id", charLibrary.ID)
|
||||
return charLibrary, nil
|
||||
}
|
||||
|
||||
// DeleteCharacter 删除单个角色
|
||||
func (s *CharacterLibraryService) DeleteCharacter(characterID uint) error {
|
||||
// 查找角色
|
||||
var character models.Character
|
||||
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("character not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证权限:检查角色所属的drama是否属于当前用户
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("unauthorized")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除角色
|
||||
if err := s.db.Delete(&character).Error; err != nil {
|
||||
s.log.Errorw("Failed to delete character", "error", err, "id", characterID)
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.Infow("Character deleted", "id", characterID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCharacterImage AI生成角色形象
|
||||
func (s *CharacterLibraryService) GenerateCharacterImage(characterID string, imageService *ImageGenerationService, modelName string) (*models.ImageGeneration, error) {
|
||||
// 查找角色
|
||||
var character models.Character
|
||||
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("character not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 查询Drama验证权限
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("unauthorized")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建生成提示词 - 使用详细的外貌描述,添加干净背景要求
|
||||
prompt := ""
|
||||
|
||||
// 优先使用appearance字段,它包含了最详细的外貌描述
|
||||
if character.Appearance != nil && *character.Appearance != "" {
|
||||
prompt = *character.Appearance
|
||||
} else if character.Description != nil && *character.Description != "" {
|
||||
prompt = *character.Description
|
||||
} else {
|
||||
prompt = character.Name
|
||||
}
|
||||
|
||||
// 添加角色画像和风格要求
|
||||
prompt += ", character portrait, full body or upper body shot"
|
||||
|
||||
// 添加干净背景要求 - 确保背景简洁不干扰主体
|
||||
prompt += ", simple clean background, plain solid color background, white or light gray background"
|
||||
prompt += ", studio lighting, professional photography"
|
||||
|
||||
// 添加质量和风格要求
|
||||
prompt += ", high quality, detailed, anime style, character design"
|
||||
prompt += ", no complex background, no scenery, focus on character"
|
||||
|
||||
// 调用图片生成服务
|
||||
dramaIDStr := fmt.Sprintf("%d", character.DramaID)
|
||||
imageType := "character"
|
||||
req := &GenerateImageRequest{
|
||||
DramaID: dramaIDStr,
|
||||
CharacterID: &character.ID,
|
||||
ImageType: imageType,
|
||||
Prompt: prompt,
|
||||
Provider: "openai", // 或从配置读取
|
||||
Model: modelName, // 使用用户指定的模型
|
||||
Size: "2560x1440", // 3,686,400像素,满足API最低要求(16:9比例)
|
||||
Quality: "standard",
|
||||
}
|
||||
|
||||
imageGen, err := imageService.GenerateImage(req)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to generate character image", "error", err)
|
||||
return nil, fmt.Errorf("图片生成失败: %w", err)
|
||||
}
|
||||
|
||||
// 异步处理:在后台监听图片生成完成,然后更新角色image_url
|
||||
go s.waitAndUpdateCharacterImage(character.ID, imageGen.ID)
|
||||
|
||||
// 立即返回ImageGeneration对象,让前端可以轮询状态
|
||||
s.log.Infow("Character image generation started", "character_id", characterID, "image_gen_id", imageGen.ID)
|
||||
return imageGen, nil
|
||||
}
|
||||
|
||||
// waitAndUpdateCharacterImage 后台异步等待图片生成完成并更新角色image_url
|
||||
func (s *CharacterLibraryService) waitAndUpdateCharacterImage(characterID uint, imageGenID uint) {
|
||||
maxAttempts := 60
|
||||
pollInterval := 5 * time.Second
|
||||
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
time.Sleep(pollInterval)
|
||||
|
||||
// 查询图片生成状态
|
||||
var imageGen models.ImageGeneration
|
||||
if err := s.db.First(&imageGen, imageGenID).Error; err != nil {
|
||||
s.log.Errorw("Failed to query image generation status", "error", err, "image_gen_id", imageGenID)
|
||||
continue
|
||||
}
|
||||
|
||||
// 检查是否完成
|
||||
if imageGen.Status == models.ImageStatusCompleted && imageGen.ImageURL != nil && *imageGen.ImageURL != "" {
|
||||
// 更新角色的image_url
|
||||
if err := s.db.Model(&models.Character{}).Where("id = ?", characterID).Update("image_url", *imageGen.ImageURL).Error; err != nil {
|
||||
s.log.Errorw("Failed to update character image_url", "error", err, "character_id", characterID)
|
||||
return
|
||||
}
|
||||
s.log.Infow("Character image updated successfully", "character_id", characterID, "image_url", *imageGen.ImageURL)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否失败
|
||||
if imageGen.Status == models.ImageStatusFailed {
|
||||
s.log.Errorw("Character image generation failed", "character_id", characterID, "image_gen_id", imageGenID, "error", imageGen.ErrorMsg)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Warnw("Character image generation timeout", "character_id", characterID, "image_gen_id", imageGenID)
|
||||
}
|
||||
|
||||
// UpdateCharacter 更新角色信息
|
||||
func (s *CharacterLibraryService) UpdateCharacter(characterID string, req interface{}) error {
|
||||
// 查找角色
|
||||
var character models.Character
|
||||
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("character not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证权限:查询角色所属的drama是否属于该用户
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("unauthorized")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 构建更新数据
|
||||
updates := make(map[string]interface{})
|
||||
|
||||
// 使用类型断言获取请求数据
|
||||
if reqMap, ok := req.(*struct {
|
||||
Name *string `json:"name"`
|
||||
Appearance *string `json:"appearance"`
|
||||
Personality *string `json:"personality"`
|
||||
Description *string `json:"description"`
|
||||
}); ok {
|
||||
if reqMap.Name != nil && *reqMap.Name != "" {
|
||||
updates["name"] = *reqMap.Name
|
||||
}
|
||||
if reqMap.Appearance != nil {
|
||||
updates["appearance"] = *reqMap.Appearance
|
||||
}
|
||||
if reqMap.Personality != nil {
|
||||
updates["personality"] = *reqMap.Personality
|
||||
}
|
||||
if reqMap.Description != nil {
|
||||
updates["description"] = *reqMap.Description
|
||||
}
|
||||
}
|
||||
|
||||
if len(updates) == 0 {
|
||||
return errors.New("no fields to update")
|
||||
}
|
||||
|
||||
// 更新角色信息
|
||||
if err := s.db.Model(&character).Updates(updates).Error; err != nil {
|
||||
s.log.Errorw("Failed to update character", "error", err, "character_id", characterID)
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.Infow("Character updated", "character_id", characterID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchGenerateCharacterImages 批量生成角色图片(并发执行)
|
||||
func (s *CharacterLibraryService) BatchGenerateCharacterImages(characterIDs []string, imageService *ImageGenerationService, modelName string) {
|
||||
s.log.Infow("Starting batch character image generation",
|
||||
"count", len(characterIDs),
|
||||
"model", modelName)
|
||||
|
||||
// 使用 goroutine 并发生成所有角色图片
|
||||
for _, characterID := range characterIDs {
|
||||
// 为每个角色启动单独的 goroutine
|
||||
go func(charID string) {
|
||||
imageGen, err := s.GenerateCharacterImage(charID, imageService, modelName)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to generate character image in batch",
|
||||
"character_id", charID,
|
||||
"error", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Infow("Character image generated in batch",
|
||||
"character_id", charID,
|
||||
"image_gen_id", imageGen.ID)
|
||||
}(characterID)
|
||||
}
|
||||
|
||||
s.log.Infow("Batch character image generation tasks submitted",
|
||||
"total", len(characterIDs))
|
||||
}
|
||||
630
application/services/drama_service.go
Normal file
630
application/services/drama_service.go
Normal file
@@ -0,0 +1,630 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DramaService struct {
|
||||
db *gorm.DB
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewDramaService(db *gorm.DB, log *logger.Logger) *DramaService {
|
||||
return &DramaService{
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateDramaRequest struct {
|
||||
Title string `json:"title" binding:"required,min=1,max=100"`
|
||||
Description string `json:"description"`
|
||||
Genre string `json:"genre"`
|
||||
Tags string `json:"tags"`
|
||||
}
|
||||
|
||||
type UpdateDramaRequest struct {
|
||||
Title string `json:"title" binding:"omitempty,min=1,max=100"`
|
||||
Description string `json:"description"`
|
||||
Genre string `json:"genre"`
|
||||
Tags string `json:"tags"`
|
||||
Status string `json:"status" binding:"omitempty,oneof=draft planning production completed archived"`
|
||||
}
|
||||
|
||||
type DramaListQuery struct {
|
||||
Page int `form:"page,default=1"`
|
||||
PageSize int `form:"page_size,default=20"`
|
||||
Status string `form:"status"`
|
||||
Genre string `form:"genre"`
|
||||
Keyword string `form:"keyword"`
|
||||
}
|
||||
|
||||
func (s *DramaService) CreateDrama(req *CreateDramaRequest) (*models.Drama, error) {
|
||||
drama := &models.Drama{
|
||||
Title: req.Title,
|
||||
Status: "draft",
|
||||
}
|
||||
|
||||
if req.Description != "" {
|
||||
drama.Description = &req.Description
|
||||
}
|
||||
if req.Genre != "" {
|
||||
drama.Genre = &req.Genre
|
||||
}
|
||||
|
||||
if err := s.db.Create(drama).Error; err != nil {
|
||||
s.log.Errorw("Failed to create drama", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.log.Infow("Drama created", "drama_id", drama.ID)
|
||||
return drama, nil
|
||||
}
|
||||
|
||||
func (s *DramaService) GetDrama(dramaID string) (*models.Drama, error) {
|
||||
var drama models.Drama
|
||||
err := s.db.Where("id = ? ", dramaID).
|
||||
Preload("Characters"). // 加载Drama级别的角色
|
||||
Preload("Scenes"). // 加载Drama级别的场景
|
||||
Preload("Episodes.Characters"). // 加载每个章节关联的角色
|
||||
Preload("Episodes.Scenes"). // 加载每个章节关联的场景
|
||||
Preload("Episodes.Storyboards", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order("storyboards.storyboard_number ASC")
|
||||
}).
|
||||
First(&drama).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("drama not found")
|
||||
}
|
||||
s.log.Errorw("Failed to get drama", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 统计每个剧集的时长(基于场景时长之和)
|
||||
for i := range drama.Episodes {
|
||||
totalDuration := 0
|
||||
for _, scene := range drama.Episodes[i].Storyboards {
|
||||
totalDuration += scene.Duration
|
||||
}
|
||||
// 更新剧集时长(秒转分钟,向上取整)
|
||||
durationMinutes := (totalDuration + 59) / 60
|
||||
drama.Episodes[i].Duration = durationMinutes
|
||||
|
||||
// 如果数据库中的时长与计算的不一致,更新数据库
|
||||
if drama.Episodes[i].Duration != durationMinutes {
|
||||
s.db.Model(&models.Episode{}).Where("id = ?", drama.Episodes[i].ID).Update("duration", durationMinutes)
|
||||
}
|
||||
|
||||
// 查询角色的图片生成状态
|
||||
for j := range drama.Episodes[i].Characters {
|
||||
var imageGen models.ImageGeneration
|
||||
err := s.db.Where("character_id = ? AND (status = ? OR status = ?)",
|
||||
drama.Episodes[i].Characters[j].ID, "pending", "processing").
|
||||
Order("created_at DESC").
|
||||
First(&imageGen).Error
|
||||
|
||||
if err == nil {
|
||||
// 找到生成中的记录,设置状态
|
||||
statusStr := string(imageGen.Status)
|
||||
drama.Episodes[i].Characters[j].ImageGenerationStatus = &statusStr
|
||||
if imageGen.ErrorMsg != nil {
|
||||
drama.Episodes[i].Characters[j].ImageGenerationError = imageGen.ErrorMsg
|
||||
}
|
||||
} else if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 检查是否有失败的记录
|
||||
err := s.db.Where("character_id = ? AND status = ?",
|
||||
drama.Episodes[i].Characters[j].ID, "failed").
|
||||
Order("created_at DESC").
|
||||
First(&imageGen).Error
|
||||
|
||||
if err == nil {
|
||||
statusStr := string(imageGen.Status)
|
||||
drama.Episodes[i].Characters[j].ImageGenerationStatus = &statusStr
|
||||
if imageGen.ErrorMsg != nil {
|
||||
drama.Episodes[i].Characters[j].ImageGenerationError = imageGen.ErrorMsg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查询场景的图片生成状态
|
||||
for j := range drama.Episodes[i].Scenes {
|
||||
var imageGen models.ImageGeneration
|
||||
err := s.db.Where("scene_id = ? AND (status = ? OR status = ?)",
|
||||
drama.Episodes[i].Scenes[j].ID, "pending", "processing").
|
||||
Order("created_at DESC").
|
||||
First(&imageGen).Error
|
||||
|
||||
if err == nil {
|
||||
// 找到生成中的记录,设置状态
|
||||
statusStr := string(imageGen.Status)
|
||||
drama.Episodes[i].Scenes[j].ImageGenerationStatus = &statusStr
|
||||
if imageGen.ErrorMsg != nil {
|
||||
drama.Episodes[i].Scenes[j].ImageGenerationError = imageGen.ErrorMsg
|
||||
}
|
||||
} else if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 检查是否有失败的记录
|
||||
err := s.db.Where("scene_id = ? AND status = ?",
|
||||
drama.Episodes[i].Scenes[j].ID, "failed").
|
||||
Order("created_at DESC").
|
||||
First(&imageGen).Error
|
||||
|
||||
if err == nil {
|
||||
statusStr := string(imageGen.Status)
|
||||
drama.Episodes[i].Scenes[j].ImageGenerationStatus = &statusStr
|
||||
if imageGen.ErrorMsg != nil {
|
||||
drama.Episodes[i].Scenes[j].ImageGenerationError = imageGen.ErrorMsg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 整合所有剧集的场景到Drama级别的Scenes字段
|
||||
sceneMap := make(map[uint]*models.Scene) // 用于去重
|
||||
for i := range drama.Episodes {
|
||||
for j := range drama.Episodes[i].Scenes {
|
||||
scene := &drama.Episodes[i].Scenes[j]
|
||||
sceneMap[scene.ID] = scene
|
||||
}
|
||||
}
|
||||
|
||||
// 将整合的场景添加到drama.Scenes
|
||||
drama.Scenes = make([]models.Scene, 0, len(sceneMap))
|
||||
for _, scene := range sceneMap {
|
||||
drama.Scenes = append(drama.Scenes, *scene)
|
||||
}
|
||||
|
||||
return &drama, nil
|
||||
}
|
||||
|
||||
func (s *DramaService) ListDramas(query *DramaListQuery) ([]models.Drama, int64, error) {
|
||||
var dramas []models.Drama
|
||||
var total int64
|
||||
|
||||
db := s.db.Model(&models.Drama{})
|
||||
|
||||
if query.Status != "" {
|
||||
db = db.Where("status = ?", query.Status)
|
||||
}
|
||||
|
||||
if query.Genre != "" {
|
||||
db = db.Where("genre = ?", query.Genre)
|
||||
}
|
||||
|
||||
if query.Keyword != "" {
|
||||
db = db.Where("title LIKE ? OR description LIKE ?", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
|
||||
}
|
||||
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
s.log.Errorw("Failed to count dramas", "error", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
err := db.Order("updated_at DESC").
|
||||
Offset(offset).
|
||||
Limit(query.PageSize).
|
||||
Preload("Episodes.Storyboards", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Order("storyboards.storyboard_number ASC")
|
||||
}).
|
||||
Find(&dramas).Error
|
||||
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to list dramas", "error", err)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 统计每个剧本的每个剧集的时长(基于场景时长之和)
|
||||
for i := range dramas {
|
||||
for j := range dramas[i].Episodes {
|
||||
totalDuration := 0
|
||||
for _, scene := range dramas[i].Episodes[j].Storyboards {
|
||||
totalDuration += scene.Duration
|
||||
}
|
||||
// 更新剧集时长(秒转分钟,向上取整)
|
||||
durationMinutes := (totalDuration + 59) / 60
|
||||
dramas[i].Episodes[j].Duration = durationMinutes
|
||||
}
|
||||
}
|
||||
|
||||
return dramas, total, nil
|
||||
}
|
||||
|
||||
func (s *DramaService) UpdateDrama(dramaID string, req *UpdateDramaRequest) (*models.Drama, error) {
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("drama not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updates := make(map[string]interface{})
|
||||
|
||||
if req.Title != "" {
|
||||
updates["title"] = req.Title
|
||||
}
|
||||
if req.Description != "" {
|
||||
updates["description"] = req.Description
|
||||
}
|
||||
if req.Genre != "" {
|
||||
updates["genre"] = req.Genre
|
||||
}
|
||||
if req.Tags != "" {
|
||||
updates["tags"] = req.Tags
|
||||
}
|
||||
if req.Status != "" {
|
||||
updates["status"] = req.Status
|
||||
}
|
||||
|
||||
updates["updated_at"] = time.Now()
|
||||
|
||||
if err := s.db.Model(&drama).Updates(updates).Error; err != nil {
|
||||
s.log.Errorw("Failed to update drama", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.log.Infow("Drama updated", "drama_id", dramaID)
|
||||
return &drama, nil
|
||||
}
|
||||
|
||||
func (s *DramaService) DeleteDrama(dramaID string) error {
|
||||
result := s.db.Where("id = ? ", dramaID).Delete(&models.Drama{})
|
||||
|
||||
if result.Error != nil {
|
||||
s.log.Errorw("Failed to delete drama", "error", result.Error)
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return errors.New("drama not found")
|
||||
}
|
||||
|
||||
s.log.Infow("Drama deleted", "drama_id", dramaID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DramaService) GetDramaStats() (map[string]interface{}, error) {
|
||||
var total int64
|
||||
var byStatus []struct {
|
||||
Status string
|
||||
Count int64
|
||||
}
|
||||
|
||||
if err := s.db.Model(&models.Drama{}).Count(&total).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.db.Model(&models.Drama{}).
|
||||
Select("status, count(*) as count").
|
||||
Group("status").
|
||||
Scan(&byStatus).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total": total,
|
||||
"by_status": byStatus,
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
type SaveOutlineRequest struct {
|
||||
Title string `json:"title" binding:"required"`
|
||||
Summary string `json:"summary" binding:"required"`
|
||||
Genre string `json:"genre"`
|
||||
Tags []string `json:"tags"`
|
||||
}
|
||||
|
||||
type SaveCharactersRequest struct {
|
||||
Characters []models.Character `json:"characters" binding:"required"`
|
||||
EpisodeID *uint `json:"episode_id"` // 可选:如果提供则关联到指定章节
|
||||
}
|
||||
|
||||
type SaveProgressRequest struct {
|
||||
CurrentStep string `json:"current_step" binding:"required"`
|
||||
StepData map[string]interface{} `json:"step_data"`
|
||||
}
|
||||
|
||||
type SaveEpisodesRequest struct {
|
||||
Episodes []models.Episode `json:"episodes" binding:"required"`
|
||||
}
|
||||
|
||||
func (s *DramaService) SaveOutline(dramaID string, req *SaveOutlineRequest) error {
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("drama not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"title": req.Title,
|
||||
"description": req.Summary,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
if req.Genre != "" {
|
||||
updates["genre"] = req.Genre
|
||||
}
|
||||
|
||||
if len(req.Tags) > 0 {
|
||||
tagsJSON, err := json.Marshal(req.Tags)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to marshal tags", "error", err)
|
||||
return err
|
||||
}
|
||||
updates["tags"] = tagsJSON
|
||||
}
|
||||
|
||||
if err := s.db.Model(&drama).Updates(updates).Error; err != nil {
|
||||
s.log.Errorw("Failed to save outline", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.Infow("Outline saved", "drama_id", dramaID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DramaService) GetCharacters(dramaID string, episodeID *string) ([]models.Character, error) {
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("drama not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var characters []models.Character
|
||||
|
||||
// 如果指定了episodeID,只获取该章节关联的角色
|
||||
if episodeID != nil {
|
||||
var episode models.Episode
|
||||
if err := s.db.Preload("Characters").Where("id = ? AND drama_id = ?", *episodeID, dramaID).First(&episode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errors.New("episode not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
characters = episode.Characters
|
||||
} else {
|
||||
// 如果没有指定episodeID,获取项目的所有角色
|
||||
if err := s.db.Where("drama_id = ?", dramaID).Find(&characters).Error; err != nil {
|
||||
s.log.Errorw("Failed to get characters", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 查询每个角色的图片生成任务状态
|
||||
for i := range characters {
|
||||
// 查询该角色最新的图片生成任务
|
||||
var imageGen models.ImageGeneration
|
||||
err := s.db.Where("character_id = ?", characters[i].ID).
|
||||
Order("created_at DESC").
|
||||
First(&imageGen).Error
|
||||
|
||||
if err == nil {
|
||||
// 如果有进行中的任务,填充状态信息
|
||||
if imageGen.Status == models.ImageStatusPending || imageGen.Status == models.ImageStatusProcessing {
|
||||
statusStr := string(imageGen.Status)
|
||||
characters[i].ImageGenerationStatus = &statusStr
|
||||
} else if imageGen.Status == models.ImageStatusFailed {
|
||||
statusStr := "failed"
|
||||
characters[i].ImageGenerationStatus = &statusStr
|
||||
if imageGen.ErrorMsg != nil {
|
||||
characters[i].ImageGenerationError = imageGen.ErrorMsg
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return characters, nil
|
||||
}
|
||||
|
||||
func (s *DramaService) SaveCharacters(dramaID string, req *SaveCharactersRequest) error {
|
||||
// 转换dramaID
|
||||
id, err := strconv.ParseUint(dramaID, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid drama ID")
|
||||
}
|
||||
dramaIDUint := uint(id)
|
||||
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", dramaIDUint).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("drama not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果指定了EpisodeID,验证章节存在性
|
||||
if req.EpisodeID != nil {
|
||||
var episode models.Episode
|
||||
if err := s.db.Where("id = ? AND drama_id = ?", *req.EpisodeID, dramaIDUint).First(&episode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("episode not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 获取该项目已存在的所有角色
|
||||
var existingCharacters []models.Character
|
||||
if err := s.db.Where("drama_id = ?", dramaIDUint).Find(&existingCharacters).Error; err != nil {
|
||||
s.log.Errorw("Failed to get existing characters", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建角色名称到角色的映射
|
||||
existingCharMap := make(map[string]*models.Character)
|
||||
for i := range existingCharacters {
|
||||
existingCharMap[existingCharacters[i].Name] = &existingCharacters[i]
|
||||
}
|
||||
|
||||
// 收集需要关联到章节的角色ID
|
||||
var characterIDs []uint
|
||||
|
||||
// 创建新角色或复用已有角色
|
||||
for _, char := range req.Characters {
|
||||
if existingChar, exists := existingCharMap[char.Name]; exists {
|
||||
// 角色已存在,直接复用
|
||||
s.log.Infow("Character already exists, reusing", "name", char.Name, "character_id", existingChar.ID)
|
||||
characterIDs = append(characterIDs, existingChar.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// 角色不存在,创建新角色
|
||||
character := models.Character{
|
||||
DramaID: dramaIDUint,
|
||||
Name: char.Name,
|
||||
Role: char.Role,
|
||||
Description: char.Description,
|
||||
Personality: char.Personality,
|
||||
Appearance: char.Appearance,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&character).Error; err != nil {
|
||||
s.log.Errorw("Failed to create character", "error", err, "name", char.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
s.log.Infow("New character created", "character_id", character.ID, "name", char.Name)
|
||||
characterIDs = append(characterIDs, character.ID)
|
||||
}
|
||||
|
||||
// 如果指定了EpisodeID,建立角色与章节的关联
|
||||
if req.EpisodeID != nil && len(characterIDs) > 0 {
|
||||
var episode models.Episode
|
||||
if err := s.db.First(&episode, *req.EpisodeID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取角色对象
|
||||
var characters []models.Character
|
||||
if err := s.db.Where("id IN ?", characterIDs).Find(&characters).Error; err != nil {
|
||||
s.log.Errorw("Failed to get characters", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用GORM的Association API建立多对多关系(会自动去重)
|
||||
if err := s.db.Model(&episode).Association("Characters").Append(&characters); err != nil {
|
||||
s.log.Errorw("Failed to associate characters with episode", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.Infow("Characters associated with episode", "episode_id", *req.EpisodeID, "character_count", len(characterIDs))
|
||||
}
|
||||
|
||||
if err := s.db.Model(&drama).Update("updated_at", time.Now()).Error; err != nil {
|
||||
s.log.Errorw("Failed to update drama timestamp", "error", err)
|
||||
}
|
||||
|
||||
s.log.Infow("Characters saved", "drama_id", dramaID, "count", len(req.Characters))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DramaService) SaveEpisodes(dramaID string, req *SaveEpisodesRequest) error {
|
||||
// 转换dramaID
|
||||
id, err := strconv.ParseUint(dramaID, 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid drama ID")
|
||||
}
|
||||
dramaIDUint := uint(id)
|
||||
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", dramaIDUint).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("drama not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除旧剧集
|
||||
if err := s.db.Where("drama_id = ?", dramaIDUint).Delete(&models.Episode{}).Error; err != nil {
|
||||
s.log.Errorw("Failed to delete old episodes", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建新剧集(不包含场景,场景由后续步骤生成)
|
||||
for _, ep := range req.Episodes {
|
||||
episode := models.Episode{
|
||||
DramaID: dramaIDUint,
|
||||
EpisodeNum: ep.EpisodeNum,
|
||||
Title: ep.Title,
|
||||
Description: ep.Description,
|
||||
ScriptContent: ep.ScriptContent,
|
||||
Duration: ep.Duration,
|
||||
Status: "draft",
|
||||
}
|
||||
|
||||
if err := s.db.Create(&episode).Error; err != nil {
|
||||
s.log.Errorw("Failed to create episode", "error", err, "episode", ep.EpisodeNum)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.db.Model(&drama).Update("updated_at", time.Now()).Error; err != nil {
|
||||
s.log.Errorw("Failed to update drama timestamp", "error", err)
|
||||
}
|
||||
|
||||
s.log.Infow("Episodes saved", "drama_id", dramaID, "count", len(req.Episodes))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DramaService) SaveProgress(dramaID string, req *SaveProgressRequest) error {
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return errors.New("drama not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 构建metadata对象
|
||||
metadata := make(map[string]interface{})
|
||||
|
||||
// 保留现有metadata
|
||||
if drama.Metadata != nil {
|
||||
if err := json.Unmarshal(drama.Metadata, &metadata); err != nil {
|
||||
s.log.Warnw("Failed to unmarshal existing metadata", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新progress信息
|
||||
metadata["current_step"] = req.CurrentStep
|
||||
if req.StepData != nil {
|
||||
metadata["step_data"] = req.StepData
|
||||
}
|
||||
|
||||
// 序列化metadata
|
||||
metadataJSON, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to marshal metadata", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"metadata": metadataJSON,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
if err := s.db.Model(&drama).Updates(updates).Error; err != nil {
|
||||
s.log.Errorw("Failed to save progress", "error", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.log.Infow("Progress saved", "drama_id", dramaID, "step", req.CurrentStep)
|
||||
return nil
|
||||
}
|
||||
428
application/services/frame_prompt_service.go
Normal file
428
application/services/frame_prompt_service.go
Normal file
@@ -0,0 +1,428 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// FramePromptService 处理帧提示词生成
|
||||
type FramePromptService struct {
|
||||
db *gorm.DB
|
||||
aiService *AIService
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
// NewFramePromptService 创建帧提示词服务
|
||||
func NewFramePromptService(db *gorm.DB, log *logger.Logger) *FramePromptService {
|
||||
return &FramePromptService{
|
||||
db: db,
|
||||
aiService: NewAIService(db, log),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// FrameType 帧类型
|
||||
type FrameType string
|
||||
|
||||
const (
|
||||
FrameTypeFirst FrameType = "first" // 首帧
|
||||
FrameTypeKey FrameType = "key" // 关键帧
|
||||
FrameTypeLast FrameType = "last" // 尾帧
|
||||
FrameTypePanel FrameType = "panel" // 分镜板(3格组合)
|
||||
FrameTypeAction FrameType = "action" // 动作序列(5格)
|
||||
)
|
||||
|
||||
// GenerateFramePromptRequest 生成帧提示词请求
|
||||
type GenerateFramePromptRequest struct {
|
||||
StoryboardID string `json:"storyboard_id"`
|
||||
FrameType FrameType `json:"frame_type"`
|
||||
// 可选参数
|
||||
PanelCount int `json:"panel_count,omitempty"` // 分镜板格数,默认3
|
||||
}
|
||||
|
||||
// FramePromptResponse 帧提示词响应
|
||||
type FramePromptResponse struct {
|
||||
FrameType FrameType `json:"frame_type"`
|
||||
SingleFrame *SingleFramePrompt `json:"single_frame,omitempty"` // 单帧提示词
|
||||
MultiFrame *MultiFramePrompt `json:"multi_frame,omitempty"` // 多帧提示词
|
||||
}
|
||||
|
||||
// SingleFramePrompt 单帧提示词
|
||||
type SingleFramePrompt struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// MultiFramePrompt 多帧提示词
|
||||
type MultiFramePrompt struct {
|
||||
Layout string `json:"layout"` // horizontal_3, grid_2x2 等
|
||||
Frames []SingleFramePrompt `json:"frames"`
|
||||
}
|
||||
|
||||
// GenerateFramePrompt 生成指定类型的帧提示词并保存到frame_prompts表
|
||||
func (s *FramePromptService) GenerateFramePrompt(req GenerateFramePromptRequest) (*FramePromptResponse, error) {
|
||||
// 查询分镜信息
|
||||
var storyboard models.Storyboard
|
||||
if err := s.db.Preload("Characters").First(&storyboard, req.StoryboardID).Error; err != nil {
|
||||
return nil, fmt.Errorf("storyboard not found: %w", err)
|
||||
}
|
||||
|
||||
// 获取场景信息
|
||||
var scene *models.Scene
|
||||
if storyboard.SceneID != nil {
|
||||
scene = &models.Scene{}
|
||||
if err := s.db.First(scene, *storyboard.SceneID).Error; err != nil {
|
||||
s.log.Warnw("Scene not found", "scene_id", *storyboard.SceneID)
|
||||
scene = nil
|
||||
}
|
||||
}
|
||||
|
||||
response := &FramePromptResponse{
|
||||
FrameType: req.FrameType,
|
||||
}
|
||||
|
||||
// 生成提示词
|
||||
switch req.FrameType {
|
||||
case FrameTypeFirst:
|
||||
response.SingleFrame = s.generateFirstFrame(storyboard, scene)
|
||||
// 保存单帧提示词
|
||||
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), response.SingleFrame.Prompt, response.SingleFrame.Description, "")
|
||||
case FrameTypeKey:
|
||||
response.SingleFrame = s.generateKeyFrame(storyboard, scene)
|
||||
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), response.SingleFrame.Prompt, response.SingleFrame.Description, "")
|
||||
case FrameTypeLast:
|
||||
response.SingleFrame = s.generateLastFrame(storyboard, scene)
|
||||
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), response.SingleFrame.Prompt, response.SingleFrame.Description, "")
|
||||
case FrameTypePanel:
|
||||
count := req.PanelCount
|
||||
if count == 0 {
|
||||
count = 3
|
||||
}
|
||||
response.MultiFrame = s.generatePanelFrames(storyboard, scene, count)
|
||||
// 保存多帧提示词(合并为一条记录)
|
||||
var prompts []string
|
||||
for _, frame := range response.MultiFrame.Frames {
|
||||
prompts = append(prompts, frame.Prompt)
|
||||
}
|
||||
combinedPrompt := strings.Join(prompts, "\n---\n")
|
||||
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), combinedPrompt, "分镜板组合提示词", response.MultiFrame.Layout)
|
||||
case FrameTypeAction:
|
||||
response.MultiFrame = s.generateActionSequence(storyboard, scene)
|
||||
var prompts []string
|
||||
for _, frame := range response.MultiFrame.Frames {
|
||||
prompts = append(prompts, frame.Prompt)
|
||||
}
|
||||
combinedPrompt := strings.Join(prompts, "\n---\n")
|
||||
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), combinedPrompt, "动作序列组合提示词", response.MultiFrame.Layout)
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported frame type: %s", req.FrameType)
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// saveFramePrompt 保存帧提示词到数据库
|
||||
func (s *FramePromptService) saveFramePrompt(storyboardID, frameType, prompt, description, layout string) {
|
||||
framePrompt := models.FramePrompt{
|
||||
StoryboardID: uint(mustParseUint(storyboardID)),
|
||||
FrameType: frameType,
|
||||
Prompt: prompt,
|
||||
}
|
||||
|
||||
if description != "" {
|
||||
framePrompt.Description = &description
|
||||
}
|
||||
if layout != "" {
|
||||
framePrompt.Layout = &layout
|
||||
}
|
||||
|
||||
// 先删除同类型的旧记录(保持最新)
|
||||
s.db.Where("storyboard_id = ? AND frame_type = ?", storyboardID, frameType).Delete(&models.FramePrompt{})
|
||||
|
||||
// 插入新记录
|
||||
if err := s.db.Create(&framePrompt).Error; err != nil {
|
||||
s.log.Warnw("Failed to save frame prompt", "error", err, "storyboard_id", storyboardID, "frame_type", frameType)
|
||||
}
|
||||
}
|
||||
|
||||
// mustParseUint 辅助函数
|
||||
func mustParseUint(s string) uint64 {
|
||||
var result uint64
|
||||
fmt.Sscanf(s, "%d", &result)
|
||||
return result
|
||||
}
|
||||
|
||||
// generateFirstFrame 生成首帧提示词
|
||||
func (s *FramePromptService) generateFirstFrame(sb models.Storyboard, scene *models.Scene) *SingleFramePrompt {
|
||||
// 构建上下文信息
|
||||
contextInfo := s.buildStoryboardContext(sb, scene)
|
||||
|
||||
// 构建AI提示词
|
||||
systemPrompt := `你是一个专业的图像生成提示词专家。请根据提供的镜头信息,生成适合用于AI图像生成的提示词。
|
||||
|
||||
重要:这是镜头的首帧 - 一个完全静态的画面,展示动作发生之前的初始状态。
|
||||
|
||||
要求:
|
||||
1. 直接输出提示词,不要任何解释说明
|
||||
2. 可以使用中文或英文,用逗号分隔关键词
|
||||
3. 只描述静态视觉元素:场景环境、角色姿态、表情、氛围、光线
|
||||
4. 不要包含任何动作动词(如:猛然、弹起、坐直、抓住等)
|
||||
5. 描述角色处于动作发生前的状态(如:躺在床上、站立、坐着等静态姿态)
|
||||
6. 适合动画风格(anime style)
|
||||
|
||||
示例格式:
|
||||
Anime style, 城市公寓卧室, 凌晨, 昏暗房间, 床上, 年轻男子躺着, 表情平静, 闭眼睡眠, 柔和光线, 静谧氛围, 中景, 平视`
|
||||
|
||||
userPrompt := fmt.Sprintf(`镜头信息:
|
||||
%s
|
||||
|
||||
请直接生成首帧的图像提示词,不要任何解释:`, contextInfo)
|
||||
|
||||
// 调用AI生成
|
||||
prompt, err := s.aiService.GenerateText(userPrompt, systemPrompt)
|
||||
if err != nil {
|
||||
s.log.Warnw("AI generation failed, using fallback", "error", err)
|
||||
// 降级方案:使用简单拼接
|
||||
prompt = s.buildFallbackPrompt(sb, scene, "first frame, static shot")
|
||||
}
|
||||
|
||||
// 如果AI返回空字符串,使用降级方案
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
s.log.Warnw("AI returned empty prompt, using fallback", "storyboard_id", sb.ID)
|
||||
prompt = s.buildFallbackPrompt(sb, scene, "first frame, static shot")
|
||||
}
|
||||
|
||||
return &SingleFramePrompt{
|
||||
Prompt: prompt,
|
||||
Description: "镜头开始的静态画面,展示初始状态",
|
||||
}
|
||||
}
|
||||
|
||||
// generateKeyFrame 生成关键帧提示词
|
||||
func (s *FramePromptService) generateKeyFrame(sb models.Storyboard, scene *models.Scene) *SingleFramePrompt {
|
||||
// 构建上下文信息
|
||||
contextInfo := s.buildStoryboardContext(sb, scene)
|
||||
|
||||
// 构建AI提示词
|
||||
systemPrompt := `你是一个专业的图像生成提示词专家。请根据提供的镜头信息,生成适合用于AI图像生成的提示词。
|
||||
|
||||
重要:这是镜头的关键帧 - 捕捉动作最激烈、最精彩的瞬间。
|
||||
|
||||
要求:
|
||||
1. 直接输出提示词,不要任何解释说明
|
||||
2. 可以使用中文或英文,用逗号分隔关键词
|
||||
3. 重点描述动作的高潮瞬间:身体姿态、运动轨迹、力量感
|
||||
4. 包含动态元素:动作模糊、速度线、冲击感
|
||||
5. 强调表情和情绪的极致状态
|
||||
6. 适合动画风格(anime style)
|
||||
|
||||
示例格式:
|
||||
Anime style, 城市街道, 白天, 男子全力冲刺, 身体前倾, 动作模糊, 速度线, 汗水飞溅, 表情坚毅, 紧张氛围, 动态镜头, 中景`
|
||||
|
||||
userPrompt := fmt.Sprintf(`镜头信息:
|
||||
%s
|
||||
|
||||
请直接生成关键帧的图像提示词,不要任何解释:`, contextInfo)
|
||||
|
||||
// 调用AI生成
|
||||
prompt, err := s.aiService.GenerateText(userPrompt, systemPrompt)
|
||||
if err != nil {
|
||||
s.log.Warnw("AI generation failed, using fallback", "error", err)
|
||||
prompt = s.buildFallbackPrompt(sb, scene, "key frame, dynamic action")
|
||||
}
|
||||
|
||||
// 如果AI返回空字符串,使用降级方案
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
s.log.Warnw("AI returned empty prompt, using fallback", "storyboard_id", sb.ID)
|
||||
prompt = s.buildFallbackPrompt(sb, scene, "key frame, dynamic action")
|
||||
}
|
||||
|
||||
return &SingleFramePrompt{
|
||||
Prompt: prompt,
|
||||
Description: "动作高潮瞬间,展示关键动作",
|
||||
}
|
||||
}
|
||||
|
||||
// generateLastFrame 生成尾帧提示词
|
||||
func (s *FramePromptService) generateLastFrame(sb models.Storyboard, scene *models.Scene) *SingleFramePrompt {
|
||||
// 构建上下文信息
|
||||
contextInfo := s.buildStoryboardContext(sb, scene)
|
||||
|
||||
// 构建AI提示词
|
||||
systemPrompt := `你是一个专业的图像生成提示词专家。请根据提供的镜头信息,生成适合用于AI图像生成的提示词。
|
||||
|
||||
重要:这是镜头的尾帧 - 一个静态画面,展示动作结束后的最终状态和结果。
|
||||
|
||||
要求:
|
||||
1. 直接输出提示词,不要任何解释说明
|
||||
2. 可以使用中文或英文,用逗号分隔关键词
|
||||
3. 只描述静态的最终状态:角色姿态、表情、环境变化
|
||||
4. 不要包含动作过程,只展示动作的结果和余韵
|
||||
5. 强调情绪的余波和氛围的沉淀
|
||||
6. 适合动画风格(anime style)
|
||||
|
||||
示例格式:
|
||||
Anime style, 房间内, 黄昏, 男子坐在椅子上, 身体放松, 表情疲惫, 长出一口气, 汗水滴落, 平静氛围, 静态镜头, 中景`
|
||||
|
||||
userPrompt := fmt.Sprintf(`镜头信息:
|
||||
%s
|
||||
|
||||
请直接生成尾帧的图像提示词,不要任何解释:`, contextInfo)
|
||||
|
||||
// 调用AI生成
|
||||
prompt, err := s.aiService.GenerateText(userPrompt, systemPrompt)
|
||||
if err != nil {
|
||||
s.log.Warnw("AI generation failed, using fallback", "error", err)
|
||||
prompt = s.buildFallbackPrompt(sb, scene, "last frame, final state")
|
||||
}
|
||||
|
||||
// 如果AI返回空字符串,使用降级方案
|
||||
prompt = strings.TrimSpace(prompt)
|
||||
if prompt == "" {
|
||||
s.log.Warnw("AI returned empty prompt, using fallback", "storyboard_id", sb.ID)
|
||||
prompt = s.buildFallbackPrompt(sb, scene, "last frame, final state")
|
||||
}
|
||||
|
||||
return &SingleFramePrompt{
|
||||
Prompt: prompt,
|
||||
Description: "镜头结束画面,展示最终状态和结果",
|
||||
}
|
||||
}
|
||||
|
||||
// generatePanelFrames 生成分镜板(多格组合)
|
||||
func (s *FramePromptService) generatePanelFrames(sb models.Storyboard, scene *models.Scene, count int) *MultiFramePrompt {
|
||||
layout := fmt.Sprintf("horizontal_%d", count)
|
||||
|
||||
frames := make([]SingleFramePrompt, count)
|
||||
|
||||
// 固定生成:首帧 -> 关键帧 -> 尾帧
|
||||
if count == 3 {
|
||||
frames[0] = *s.generateFirstFrame(sb, scene)
|
||||
frames[0].Description = "第1格:初始状态"
|
||||
|
||||
frames[1] = *s.generateKeyFrame(sb, scene)
|
||||
frames[1].Description = "第2格:动作高潮"
|
||||
|
||||
frames[2] = *s.generateLastFrame(sb, scene)
|
||||
frames[2].Description = "第3格:最终状态"
|
||||
} else if count == 4 {
|
||||
// 4格:首帧 -> 中间帧1 -> 中间帧2 -> 尾帧
|
||||
frames[0] = *s.generateFirstFrame(sb, scene)
|
||||
frames[1] = *s.generateKeyFrame(sb, scene)
|
||||
frames[2] = *s.generateKeyFrame(sb, scene)
|
||||
frames[3] = *s.generateLastFrame(sb, scene)
|
||||
}
|
||||
|
||||
return &MultiFramePrompt{
|
||||
Layout: layout,
|
||||
Frames: frames,
|
||||
}
|
||||
}
|
||||
|
||||
// generateActionSequence 生成动作序列(5-8格)
|
||||
func (s *FramePromptService) generateActionSequence(sb models.Storyboard, scene *models.Scene) *MultiFramePrompt {
|
||||
// 将动作分解为5个步骤
|
||||
frames := make([]SingleFramePrompt, 5)
|
||||
|
||||
// 简化实现:均匀分布从首帧到尾帧
|
||||
frames[0] = *s.generateFirstFrame(sb, scene)
|
||||
frames[1] = *s.generateKeyFrame(sb, scene)
|
||||
frames[2] = *s.generateKeyFrame(sb, scene)
|
||||
frames[3] = *s.generateKeyFrame(sb, scene)
|
||||
frames[4] = *s.generateLastFrame(sb, scene)
|
||||
|
||||
return &MultiFramePrompt{
|
||||
Layout: "horizontal_5",
|
||||
Frames: frames,
|
||||
}
|
||||
}
|
||||
|
||||
// buildStoryboardContext 构建镜头上下文信息
|
||||
func (s *FramePromptService) buildStoryboardContext(sb models.Storyboard, scene *models.Scene) string {
|
||||
var parts []string
|
||||
|
||||
// 镜头描述(最重要)
|
||||
if sb.Description != nil && *sb.Description != "" {
|
||||
parts = append(parts, fmt.Sprintf("镜头描述: %s", *sb.Description))
|
||||
}
|
||||
|
||||
// 场景信息
|
||||
if scene != nil {
|
||||
parts = append(parts, fmt.Sprintf("场景: %s, %s", scene.Location, scene.Time))
|
||||
} else if sb.Location != nil && sb.Time != nil {
|
||||
parts = append(parts, fmt.Sprintf("场景: %s, %s", *sb.Location, *sb.Time))
|
||||
}
|
||||
|
||||
// 角色
|
||||
if len(sb.Characters) > 0 {
|
||||
var charNames []string
|
||||
for _, char := range sb.Characters {
|
||||
charNames = append(charNames, char.Name)
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("角色: %s", strings.Join(charNames, ", ")))
|
||||
}
|
||||
|
||||
// 动作
|
||||
if sb.Action != nil && *sb.Action != "" {
|
||||
parts = append(parts, fmt.Sprintf("动作: %s", *sb.Action))
|
||||
}
|
||||
|
||||
// 结果
|
||||
if sb.Result != nil && *sb.Result != "" {
|
||||
parts = append(parts, fmt.Sprintf("结果: %s", *sb.Result))
|
||||
}
|
||||
|
||||
// 对白
|
||||
if sb.Dialogue != nil && *sb.Dialogue != "" {
|
||||
parts = append(parts, fmt.Sprintf("对白: %s", *sb.Dialogue))
|
||||
}
|
||||
|
||||
// 氛围
|
||||
if sb.Atmosphere != nil && *sb.Atmosphere != "" {
|
||||
parts = append(parts, fmt.Sprintf("氛围: %s", *sb.Atmosphere))
|
||||
}
|
||||
|
||||
// 镜头参数
|
||||
if sb.ShotType != nil {
|
||||
parts = append(parts, fmt.Sprintf("景别: %s", *sb.ShotType))
|
||||
}
|
||||
if sb.Angle != nil {
|
||||
parts = append(parts, fmt.Sprintf("角度: %s", *sb.Angle))
|
||||
}
|
||||
if sb.Movement != nil {
|
||||
parts = append(parts, fmt.Sprintf("运镜: %s", *sb.Movement))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "\n")
|
||||
}
|
||||
|
||||
// buildFallbackPrompt 构建降级提示词(AI失败时使用)
|
||||
func (s *FramePromptService) buildFallbackPrompt(sb models.Storyboard, scene *models.Scene, suffix string) string {
|
||||
var parts []string
|
||||
|
||||
// 场景
|
||||
if scene != nil {
|
||||
parts = append(parts, fmt.Sprintf("%s, %s", scene.Location, scene.Time))
|
||||
}
|
||||
|
||||
// 角色
|
||||
if len(sb.Characters) > 0 {
|
||||
for _, char := range sb.Characters {
|
||||
parts = append(parts, char.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// 氛围
|
||||
if sb.Atmosphere != nil {
|
||||
parts = append(parts, *sb.Atmosphere)
|
||||
}
|
||||
|
||||
parts = append(parts, "anime style", suffix)
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
984
application/services/image_generation_service.go
Normal file
984
application/services/image_generation_service.go
Normal file
@@ -0,0 +1,984 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
models "github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/infrastructure/storage"
|
||||
"github.com/drama-generator/backend/pkg/ai"
|
||||
"github.com/drama-generator/backend/pkg/image"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"github.com/drama-generator/backend/pkg/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ImageGenerationService struct {
|
||||
db *gorm.DB
|
||||
aiService *AIService
|
||||
transferService *ResourceTransferService
|
||||
localStorage *storage.LocalStorage
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
// truncateImageURL 截断图片 URL,避免 base64 格式的 URL 占满日志
|
||||
func truncateImageURL(url string) string {
|
||||
if url == "" {
|
||||
return ""
|
||||
}
|
||||
// 如果是 data URI 格式(base64),只显示前缀
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
if len(url) > 50 {
|
||||
return url[:50] + "...[base64 data]"
|
||||
}
|
||||
}
|
||||
// 普通 URL 如果过长也截断
|
||||
if len(url) > 100 {
|
||||
return url[:100] + "..."
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService {
|
||||
return &ImageGenerationService{
|
||||
db: db,
|
||||
aiService: NewAIService(db, log),
|
||||
transferService: transferService,
|
||||
localStorage: localStorage,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// GetDB 获取数据库连接
|
||||
func (s *ImageGenerationService) GetDB() *gorm.DB {
|
||||
return s.db
|
||||
}
|
||||
|
||||
type GenerateImageRequest struct {
|
||||
StoryboardID *uint `json:"storyboard_id"`
|
||||
DramaID string `json:"drama_id" binding:"required"`
|
||||
SceneID *uint `json:"scene_id"`
|
||||
CharacterID *uint `json:"character_id"`
|
||||
ImageType string `json:"image_type"` // character, scene, storyboard
|
||||
FrameType *string `json:"frame_type"` // first, key, last, panel, action
|
||||
Prompt string `json:"prompt" binding:"required,min=5,max=2000"`
|
||||
NegativePrompt *string `json:"negative_prompt"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Size string `json:"size"`
|
||||
Quality string `json:"quality"`
|
||||
Style *string `json:"style"`
|
||||
Steps *int `json:"steps"`
|
||||
CfgScale *float64 `json:"cfg_scale"`
|
||||
Seed *int64 `json:"seed"`
|
||||
Width *int `json:"width"`
|
||||
Height *int `json:"height"`
|
||||
ReferenceImages []string `json:"reference_images"` // 参考图片URL列表
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) GenerateImage(request *GenerateImageRequest) (*models.ImageGeneration, error) {
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", request.DramaID).First(&drama).Error; err != nil {
|
||||
return nil, fmt.Errorf("drama not found")
|
||||
}
|
||||
|
||||
// 注意:SceneID可能指向Scene或Storyboard表,调用方已经做过权限验证,这里不再重复验证
|
||||
|
||||
provider := request.Provider
|
||||
if provider == "" {
|
||||
provider = "openai"
|
||||
}
|
||||
|
||||
// 序列化参考图片
|
||||
var referenceImagesJSON []byte
|
||||
if len(request.ReferenceImages) > 0 {
|
||||
referenceImagesJSON, _ = json.Marshal(request.ReferenceImages)
|
||||
}
|
||||
|
||||
// 转换DramaID
|
||||
dramaIDParsed, err := strconv.ParseUint(request.DramaID, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid drama ID")
|
||||
}
|
||||
|
||||
// 设置默认图片类型
|
||||
imageType := request.ImageType
|
||||
if imageType == "" {
|
||||
imageType = string(models.ImageTypeStoryboard)
|
||||
}
|
||||
|
||||
imageGen := &models.ImageGeneration{
|
||||
StoryboardID: request.StoryboardID,
|
||||
DramaID: uint(dramaIDParsed),
|
||||
SceneID: request.SceneID,
|
||||
CharacterID: request.CharacterID,
|
||||
ImageType: imageType,
|
||||
FrameType: request.FrameType,
|
||||
Provider: provider,
|
||||
Prompt: request.Prompt,
|
||||
NegPrompt: request.NegativePrompt,
|
||||
Model: request.Model,
|
||||
Size: request.Size,
|
||||
ReferenceImages: referenceImagesJSON,
|
||||
Quality: request.Quality,
|
||||
Style: request.Style,
|
||||
Steps: request.Steps,
|
||||
CfgScale: request.CfgScale,
|
||||
Seed: request.Seed,
|
||||
Width: request.Width,
|
||||
Height: request.Height,
|
||||
Status: models.ImageStatusPending,
|
||||
}
|
||||
|
||||
if err := s.db.Create(imageGen).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||
}
|
||||
|
||||
go s.ProcessImageGeneration(imageGen.ID)
|
||||
|
||||
return imageGen, nil
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) ProcessImageGeneration(imageGenID uint) {
|
||||
var imageGen models.ImageGeneration
|
||||
if err := s.db.First(&imageGen, imageGenID).Error; err != nil {
|
||||
s.log.Errorw("Failed to load image generation", "error", err, "id", imageGenID)
|
||||
return
|
||||
}
|
||||
|
||||
s.db.Model(&imageGen).Update("status", models.ImageStatusProcessing)
|
||||
|
||||
// 如果关联了background,同步更新background为generating状态
|
||||
if imageGen.StoryboardID != nil {
|
||||
if err := s.db.Model(&models.Scene{}).Where("id = ?", *imageGen.StoryboardID).Update("status", "generating").Error; err != nil {
|
||||
s.log.Warnw("Failed to update background status to generating", "scene_id", *imageGen.StoryboardID, "error", err)
|
||||
} else {
|
||||
s.log.Infow("Background status updated to generating", "scene_id", *imageGen.StoryboardID)
|
||||
}
|
||||
}
|
||||
|
||||
client, err := s.getImageClientWithModel(imageGen.Provider, imageGen.Model)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to get image client", "error", err, "provider", imageGen.Provider, "model", imageGen.Model)
|
||||
s.updateImageGenError(imageGenID, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 解析参考图片
|
||||
var referenceImages []string
|
||||
if len(imageGen.ReferenceImages) > 0 {
|
||||
if err := json.Unmarshal(imageGen.ReferenceImages, &referenceImages); err == nil {
|
||||
s.log.Infow("Using reference images for generation",
|
||||
"id", imageGenID,
|
||||
"reference_count", len(referenceImages),
|
||||
"references", referenceImages)
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Infow("Starting image generation", "id", imageGenID, "prompt", imageGen.Prompt, "provider", imageGen.Provider)
|
||||
|
||||
var opts []image.ImageOption
|
||||
if imageGen.NegPrompt != nil && *imageGen.NegPrompt != "" {
|
||||
opts = append(opts, image.WithNegativePrompt(*imageGen.NegPrompt))
|
||||
}
|
||||
if imageGen.Size != "" {
|
||||
opts = append(opts, image.WithSize(imageGen.Size))
|
||||
}
|
||||
if imageGen.Quality != "" {
|
||||
opts = append(opts, image.WithQuality(imageGen.Quality))
|
||||
}
|
||||
if imageGen.Style != nil && *imageGen.Style != "" {
|
||||
opts = append(opts, image.WithStyle(*imageGen.Style))
|
||||
}
|
||||
if imageGen.Steps != nil {
|
||||
opts = append(opts, image.WithSteps(*imageGen.Steps))
|
||||
}
|
||||
if imageGen.CfgScale != nil {
|
||||
opts = append(opts, image.WithCfgScale(*imageGen.CfgScale))
|
||||
}
|
||||
if imageGen.Seed != nil {
|
||||
opts = append(opts, image.WithSeed(*imageGen.Seed))
|
||||
}
|
||||
if imageGen.Model != "" {
|
||||
opts = append(opts, image.WithModel(imageGen.Model))
|
||||
}
|
||||
if imageGen.Width != nil && imageGen.Height != nil {
|
||||
opts = append(opts, image.WithDimensions(*imageGen.Width, *imageGen.Height))
|
||||
}
|
||||
// 添加参考图片
|
||||
if len(referenceImages) > 0 {
|
||||
opts = append(opts, image.WithReferenceImages(referenceImages))
|
||||
}
|
||||
|
||||
result, err := client.GenerateImage(imageGen.Prompt, opts...)
|
||||
if err != nil {
|
||||
s.log.Errorw("Image generation API call failed", "error", err, "id", imageGenID, "prompt", imageGen.Prompt)
|
||||
s.updateImageGenError(imageGenID, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Infow("Image generation API call completed", "id", imageGenID, "completed", result.Completed, "has_url", result.ImageURL != "")
|
||||
|
||||
if !result.Completed {
|
||||
s.db.Model(&imageGen).Updates(map[string]interface{}{
|
||||
"status": models.ImageStatusProcessing,
|
||||
"task_id": result.TaskID,
|
||||
})
|
||||
go s.pollTaskStatus(imageGenID, client, result.TaskID)
|
||||
return
|
||||
}
|
||||
|
||||
s.completeImageGeneration(imageGenID, result)
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) pollTaskStatus(imageGenID uint, client image.ImageClient, taskID string) {
|
||||
maxAttempts := 60
|
||||
pollInterval := 5 * time.Second
|
||||
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
time.Sleep(pollInterval)
|
||||
|
||||
result, err := client.GetTaskStatus(taskID)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to get task status", "error", err, "task_id", taskID)
|
||||
continue
|
||||
}
|
||||
|
||||
if result.Completed {
|
||||
s.completeImageGeneration(imageGenID, result)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
s.updateImageGenError(imageGenID, result.Error)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.updateImageGenError(imageGenID, "timeout: image generation took too long")
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result *image.ImageResult) {
|
||||
now := time.Now()
|
||||
|
||||
// 下载图片到本地存储(仅用于缓存,不更新数据库)
|
||||
// 仅下载 HTTP/HTTPS URL,跳过 data URI
|
||||
if s.localStorage != nil && result.ImageURL != "" &&
|
||||
(strings.HasPrefix(result.ImageURL, "http://") || strings.HasPrefix(result.ImageURL, "https://")) {
|
||||
_, err := s.localStorage.DownloadFromURL(result.ImageURL, "images")
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if len(errStr) > 200 {
|
||||
errStr = errStr[:200] + "..."
|
||||
}
|
||||
s.log.Warnw("Failed to download image to local storage",
|
||||
"error", errStr,
|
||||
"id", imageGenID,
|
||||
"original_url", truncateImageURL(result.ImageURL))
|
||||
} else {
|
||||
s.log.Infow("Image downloaded to local storage for caching",
|
||||
"id", imageGenID,
|
||||
"original_url", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
|
||||
// 数据库中保持使用原始URL
|
||||
updates := map[string]interface{}{
|
||||
"status": models.ImageStatusCompleted,
|
||||
"image_url": result.ImageURL,
|
||||
"completed_at": now,
|
||||
}
|
||||
|
||||
if result.Width > 0 {
|
||||
updates["width"] = result.Width
|
||||
}
|
||||
if result.Height > 0 {
|
||||
updates["height"] = result.Height
|
||||
}
|
||||
|
||||
// 更新image_generation记录
|
||||
var imageGen models.ImageGeneration
|
||||
if err := s.db.Where("id = ?", imageGenID).First(&imageGen).Error; err != nil {
|
||||
s.log.Errorw("Failed to load image generation", "error", err, "id", imageGenID)
|
||||
return
|
||||
}
|
||||
|
||||
s.db.Model(&models.ImageGeneration{}).Where("id = ?", imageGenID).Updates(updates)
|
||||
s.log.Infow("Image generation completed", "id", imageGenID)
|
||||
|
||||
// 如果关联了storyboard,同步更新storyboard的composed_image
|
||||
if imageGen.StoryboardID != nil {
|
||||
if err := s.db.Model(&models.Storyboard{}).Where("id = ?", *imageGen.StoryboardID).Update("composed_image", result.ImageURL).Error; err != nil {
|
||||
s.log.Errorw("Failed to update storyboard composed_image", "error", err, "storyboard_id", *imageGen.StoryboardID)
|
||||
} else {
|
||||
s.log.Infow("Storyboard updated with composed image",
|
||||
"storyboard_id", *imageGen.StoryboardID,
|
||||
"composed_image", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
|
||||
// 如果关联了scene,同步更新scene的image_url和status(仅当ImageType是scene时)
|
||||
if imageGen.SceneID != nil && imageGen.ImageType == string(models.ImageTypeScene) {
|
||||
sceneUpdates := map[string]interface{}{
|
||||
"status": "generated",
|
||||
"image_url": result.ImageURL,
|
||||
}
|
||||
if err := s.db.Model(&models.Scene{}).Where("id = ?", *imageGen.SceneID).Updates(sceneUpdates).Error; err != nil {
|
||||
s.log.Errorw("Failed to update scene", "error", err, "scene_id", *imageGen.SceneID)
|
||||
} else {
|
||||
s.log.Infow("Scene updated with generated image",
|
||||
"scene_id", *imageGen.SceneID,
|
||||
"image_url", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
|
||||
// 如果关联了角色,同步更新角色的image_url
|
||||
if imageGen.CharacterID != nil {
|
||||
if err := s.db.Model(&models.Character{}).Where("id = ?", *imageGen.CharacterID).Update("image_url", result.ImageURL).Error; err != nil {
|
||||
s.log.Errorw("Failed to update character image_url", "error", err, "character_id", *imageGen.CharacterID)
|
||||
} else {
|
||||
s.log.Infow("Character updated with generated image",
|
||||
"character_id", *imageGen.CharacterID,
|
||||
"image_url", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) updateImageGenError(imageGenID uint, errorMsg string) {
|
||||
// 先获取image_generation记录
|
||||
var imageGen models.ImageGeneration
|
||||
if err := s.db.Where("id = ?", imageGenID).First(&imageGen).Error; err != nil {
|
||||
s.log.Errorw("Failed to load image generation", "error", err, "id", imageGenID)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新image_generation状态
|
||||
s.db.Model(&models.ImageGeneration{}).Where("id = ?", imageGenID).Updates(map[string]interface{}{
|
||||
"status": models.ImageStatusFailed,
|
||||
"error_msg": errorMsg,
|
||||
})
|
||||
s.log.Errorw("Image generation failed", "id", imageGenID, "error", errorMsg)
|
||||
|
||||
// 如果关联了scene,同步更新scene为失败状态
|
||||
if imageGen.SceneID != nil {
|
||||
s.db.Model(&models.Scene{}).Where("id = ?", *imageGen.SceneID).Update("status", "failed")
|
||||
s.log.Warnw("Scene marked as failed", "scene_id", *imageGen.SceneID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) getImageClient(provider string) (image.ImageClient, error) {
|
||||
config, err := s.aiService.GetDefaultConfig("image")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no image AI config found: %w", err)
|
||||
}
|
||||
|
||||
// 使用第一个模型
|
||||
model := ""
|
||||
if len(config.Model) > 0 {
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
// 使用配置中的 provider,如果没有则使用传入的 provider
|
||||
actualProvider := config.Provider
|
||||
if actualProvider == "" {
|
||||
actualProvider = provider
|
||||
}
|
||||
|
||||
// 根据 provider 自动设置默认端点
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch actualProvider {
|
||||
case "openai", "dalle":
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "chatfire":
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "volcengine", "volces", "doubao":
|
||||
endpoint = "/images/generations"
|
||||
queryEndpoint = ""
|
||||
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
case "gemini", "google":
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
default:
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
// getImageClientWithModel 根据模型名称获取图片客户端
|
||||
func (s *ImageGenerationService) getImageClientWithModel(provider string, modelName string) (image.ImageClient, error) {
|
||||
var config *models.AIServiceConfig
|
||||
var err error
|
||||
|
||||
// 如果指定了模型,尝试获取对应的配置
|
||||
if modelName != "" {
|
||||
config, err = s.aiService.GetConfigForModel("image", modelName)
|
||||
if err != nil {
|
||||
s.log.Warnw("Failed to get config for model, using default", "model", modelName, "error", err)
|
||||
config, err = s.aiService.GetDefaultConfig("image")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no image AI config found: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
config, err = s.aiService.GetDefaultConfig("image")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no image AI config found: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 使用指定的模型或配置中的第一个模型
|
||||
model := modelName
|
||||
if model == "" && len(config.Model) > 0 {
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
// 使用配置中的 provider,如果没有则使用传入的 provider
|
||||
actualProvider := config.Provider
|
||||
if actualProvider == "" {
|
||||
actualProvider = provider
|
||||
}
|
||||
|
||||
// 根据 provider 自动设置默认端点
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch actualProvider {
|
||||
case "openai", "dalle":
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "chatfire":
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "volcengine", "volces", "doubao":
|
||||
endpoint = "/images/generations"
|
||||
queryEndpoint = ""
|
||||
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
case "gemini", "google":
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
default:
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) GetImageGeneration(imageGenID uint) (*models.ImageGeneration, error) {
|
||||
var imageGen models.ImageGeneration
|
||||
if err := s.db.Where("id = ? ", imageGenID).First(&imageGen).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &imageGen, nil
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) ListImageGenerations(dramaID *uint, sceneID *uint, storyboardID *uint, frameType string, status string, page, pageSize int) ([]models.ImageGeneration, int64, error) {
|
||||
query := s.db.Model(&models.ImageGeneration{})
|
||||
|
||||
if dramaID != nil {
|
||||
query = query.Where("drama_id = ?", *dramaID)
|
||||
}
|
||||
|
||||
if sceneID != nil {
|
||||
query = query.Where("scene_id = ?", *sceneID)
|
||||
}
|
||||
|
||||
if storyboardID != nil {
|
||||
query = query.Where("storyboard_id = ?", *storyboardID)
|
||||
}
|
||||
|
||||
if frameType != "" {
|
||||
query = query.Where("frame_type = ?", frameType)
|
||||
}
|
||||
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var images []models.ImageGeneration
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&images).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return images, total, nil
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) DeleteImageGeneration(imageGenID uint) error {
|
||||
result := s.db.Where("id = ? ", imageGenID).Delete(&models.ImageGeneration{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("image generation not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) GenerateImagesForScene(sceneID string) ([]*models.ImageGeneration, error) {
|
||||
// 转换sceneID
|
||||
sid, err := strconv.ParseUint(sceneID, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid scene ID")
|
||||
}
|
||||
sceneIDUint := uint(sid)
|
||||
|
||||
var scene models.Scene
|
||||
if err := s.db.Where("id = ?", sceneIDUint).First(&scene).Error; err != nil {
|
||||
return nil, fmt.Errorf("scene not found")
|
||||
}
|
||||
|
||||
// 构建场景图片生成提示词
|
||||
prompt := scene.Prompt
|
||||
if prompt == "" {
|
||||
// 如果Prompt为空,使用Location和Time构建
|
||||
prompt = fmt.Sprintf("%s场景,%s", scene.Location, scene.Time)
|
||||
}
|
||||
|
||||
req := &GenerateImageRequest{
|
||||
SceneID: &sceneIDUint,
|
||||
DramaID: fmt.Sprintf("%d", scene.DramaID),
|
||||
ImageType: string(models.ImageTypeScene),
|
||||
Prompt: prompt,
|
||||
}
|
||||
|
||||
imageGen, err := s.GenerateImage(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return []*models.ImageGeneration{imageGen}, nil
|
||||
}
|
||||
|
||||
// BackgroundInfo 背景信息结构
|
||||
type BackgroundInfo struct {
|
||||
Location string `json:"location"`
|
||||
Time string `json:"time"`
|
||||
Atmosphere string `json:"atmosphere"`
|
||||
Prompt string `json:"prompt"`
|
||||
StoryboardNumbers []int `json:"storyboard_numbers"`
|
||||
SceneIDs []uint `json:"scene_ids"`
|
||||
StoryboardCount int `json:"scene_count"`
|
||||
}
|
||||
|
||||
func (s *ImageGenerationService) BatchGenerateImagesForEpisode(episodeID string) ([]*models.ImageGeneration, error) {
|
||||
var ep models.Episode
|
||||
if err := s.db.Preload("Drama").Where("id = ?", episodeID).First(&ep).Error; err != nil {
|
||||
return nil, fmt.Errorf("episode not found")
|
||||
}
|
||||
// 从数据库读取已保存的场景
|
||||
var scenes []models.Storyboard
|
||||
if err := s.db.Where("episode_id = ?", episodeID).Find(&scenes).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to get scenes: %w", err)
|
||||
}
|
||||
|
||||
backgrounds := s.extractUniqueBackgrounds(scenes)
|
||||
s.log.Infow("Extracted unique backgrounds",
|
||||
"episode_id", episodeID,
|
||||
"background_count", len(backgrounds))
|
||||
|
||||
// 为每个背景生成图片
|
||||
var results []*models.ImageGeneration
|
||||
for _, bg := range scenes {
|
||||
if bg.ImagePrompt == nil || *bg.ImagePrompt == "" {
|
||||
s.log.Warnw("Background has no prompt, skipping", "scene_id", bg.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
// 更新背景状态为处理中
|
||||
s.db.Model(bg).Update("status", "generating")
|
||||
|
||||
req := &GenerateImageRequest{
|
||||
StoryboardID: &bg.ID,
|
||||
DramaID: fmt.Sprintf("%d", ep.DramaID),
|
||||
Prompt: *bg.ImagePrompt,
|
||||
}
|
||||
|
||||
imageGen, err := s.GenerateImage(req)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to generate image for background",
|
||||
"scene_id", bg.ID,
|
||||
"location", bg.Location,
|
||||
"error", err)
|
||||
s.db.Model(bg).Update("status", "failed")
|
||||
continue
|
||||
}
|
||||
|
||||
s.log.Infow("Background image generation started",
|
||||
"scene_id", bg.ID,
|
||||
"image_gen_id", imageGen.ID,
|
||||
"location", bg.Location,
|
||||
"time", bg.Time)
|
||||
|
||||
results = append(results, imageGen)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetScencesForEpisode 获取项目的场景列表(项目级)
|
||||
func (s *ImageGenerationService) GetScencesForEpisode(episodeID string) ([]*models.Scene, error) {
|
||||
var episode models.Episode
|
||||
if err := s.db.Preload("Drama").Where("id = ?", episodeID).First(&episode).Error; err != nil {
|
||||
return nil, fmt.Errorf("episode not found")
|
||||
}
|
||||
|
||||
// 场景是项目级的,通过drama_id查询
|
||||
var scenes []*models.Scene
|
||||
if err := s.db.Where("drama_id = ?", episode.DramaID).Order("location ASC, time ASC").Find(&scenes).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to load scenes: %w", err)
|
||||
}
|
||||
|
||||
return scenes, nil
|
||||
}
|
||||
|
||||
// ExtractBackgroundsForEpisode 从剧本内容中提取场景并保存到项目级别数据库
|
||||
func (s *ImageGenerationService) ExtractBackgroundsForEpisode(episodeID string) ([]*models.Scene, error) {
|
||||
var episode models.Episode
|
||||
if err := s.db.Preload("Drama").Where("id = ?", episodeID).First(&episode).Error; err != nil {
|
||||
return nil, fmt.Errorf("episode not found")
|
||||
}
|
||||
|
||||
// 检查是否有剧本内容
|
||||
if episode.ScriptContent == nil || *episode.ScriptContent == "" {
|
||||
return nil, fmt.Errorf("剧本内容为空,无法提取场景")
|
||||
}
|
||||
|
||||
dramaID := episode.DramaID
|
||||
|
||||
// 使用AI从剧本内容中提取场景
|
||||
backgroundsInfo, err := s.extractBackgroundsFromScript(*episode.ScriptContent, dramaID)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to extract backgrounds from script", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保存到数据库(不涉及Storyboard关联,因为此时还没有生成分镜)
|
||||
var scenes []*models.Scene
|
||||
err = s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 先删除该章节的所有场景(实现重新提取覆盖功能)
|
||||
if err := tx.Where("episode_id = ?", episode.ID).Delete(&models.Scene{}).Error; err != nil {
|
||||
s.log.Errorw("Failed to delete old scenes", "error", err)
|
||||
return err
|
||||
}
|
||||
s.log.Infow("Deleted old scenes for re-extraction", "episode_id", episode.ID)
|
||||
|
||||
// 创建新提取的场景
|
||||
for _, bgInfo := range backgroundsInfo {
|
||||
// 保存新场景到数据库(章节级)
|
||||
episodeIDVal := episode.ID
|
||||
scene := &models.Scene{
|
||||
DramaID: dramaID,
|
||||
EpisodeID: &episodeIDVal,
|
||||
Location: bgInfo.Location,
|
||||
Time: bgInfo.Time,
|
||||
Prompt: bgInfo.Prompt,
|
||||
StoryboardCount: 1, // 默认为1
|
||||
Status: "pending",
|
||||
}
|
||||
if err := tx.Create(scene).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
scenes = append(scenes, scene)
|
||||
|
||||
s.log.Infow("Created new scene from script",
|
||||
"scene_id", scene.ID,
|
||||
"location", scene.Location,
|
||||
"time", scene.Time)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.log.Infow("Saved scenes to database",
|
||||
"episode_id", episodeID,
|
||||
"total_storyboards", len(episode.Storyboards),
|
||||
"unique_scenes", len(scenes))
|
||||
|
||||
return scenes, nil
|
||||
}
|
||||
|
||||
// extractBackgroundsFromScript 从剧本内容中使用AI提取场景信息
|
||||
func (s *ImageGenerationService) extractBackgroundsFromScript(scriptContent string, dramaID uint) ([]BackgroundInfo, error) {
|
||||
if scriptContent == "" {
|
||||
return []BackgroundInfo{}, nil
|
||||
}
|
||||
|
||||
// 获取AI客户端
|
||||
client, err := s.aiService.GetAIClient("text")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get AI client: %w", err)
|
||||
}
|
||||
|
||||
// 构建AI提示词
|
||||
prompt := fmt.Sprintf(`【任务】分析以下剧本内容,提取出所有需要的场景背景信息。
|
||||
|
||||
【剧本内容】
|
||||
%s
|
||||
|
||||
【要求】
|
||||
1. 识别剧本中所有不同的场景(地点+时间组合)
|
||||
2. 为每个场景生成详细的**中文**图片生成提示词(Prompt)
|
||||
3. **重要**:场景描述必须是**纯背景**,不能包含人物、角色、动作等元素
|
||||
4. Prompt要求:
|
||||
- **必须使用中文**,不能包含英文字符
|
||||
- 详细描述场景环境、建筑、物品、光线、氛围等
|
||||
- **禁止描述人物、角色、动作、对话等**
|
||||
- 适合AI图片生成模型使用
|
||||
- 风格统一为:电影感、细节丰富、动漫风格、高质量
|
||||
5. location、time、atmosphere和prompt字段都使用中文
|
||||
6. 提取场景的氛围描述(atmosphere)
|
||||
|
||||
【输出JSON格式】
|
||||
{
|
||||
"backgrounds": [
|
||||
{
|
||||
"location": "地点名称(中文)",
|
||||
"time": "时间描述(中文)",
|
||||
"atmosphere": "氛围描述(中文)",
|
||||
"prompt": "一个电影感的动漫风格纯背景场景,展现[地点描述]在[时间]的环境。画面呈现[环境细节、建筑、物品、光线等,不包含人物]。风格:细节丰富,高质量,氛围光照。情绪:[环境情绪描述]。"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
【示例】
|
||||
正确示例(注意:不包含人物):
|
||||
{
|
||||
"backgrounds": [
|
||||
{
|
||||
"location": "维修店内部",
|
||||
"time": "深夜",
|
||||
"atmosphere": "昏暗、孤独、工业感",
|
||||
"prompt": "一个电影感的动漫风格纯背景场景,展现凌乱的维修店内部在深夜的环境。昏暗的日光灯照射下,工作台上散落着各种扳手、螺丝刀和机械零件,墙上挂着油污斑斑的工具挂板和褪色海报,地面有油渍痕迹,角落堆放着废旧轮胎。风格:细节丰富,高质量,昏暗氛围。情绪:孤独、工业感。"
|
||||
},
|
||||
{
|
||||
"location": "城市街道",
|
||||
"time": "黄昏",
|
||||
"atmosphere": "温暖、繁忙、生活气息",
|
||||
"prompt": "一个电影感的动漫风格纯背景场景,展现繁华的城市街道在黄昏时分的环境。夕阳的余晖洒在街道的沥青路面上,两旁的商铺霓虹灯开始点亮,街边有自行车停靠架和公交站牌,远处高楼林立,天空呈现橙红色渐变。风格:细节丰富,高质量,温暖氛围。情绪:生活气息、繁忙。"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
【错误示例(包含人物,禁止)】:
|
||||
❌ "展现主角站在街道上的场景" - 包含人物
|
||||
❌ "人们匆匆而过" - 包含人物
|
||||
❌ "角色在房间里活动" - 包含人物
|
||||
|
||||
请严格按照JSON格式输出,确保所有字段都使用中文。`, scriptContent)
|
||||
|
||||
response, err := client.GenerateText(prompt, "", ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to extract backgrounds with AI", "error", err)
|
||||
return nil, fmt.Errorf("AI提取场景失败: %w", err)
|
||||
}
|
||||
s.log.Infow("AI backgrounds extraction response", "length", len(response))
|
||||
|
||||
// 解析JSON响应
|
||||
var result struct {
|
||||
Backgrounds []BackgroundInfo `json:"backgrounds"`
|
||||
}
|
||||
if err := utils.SafeParseAIJSON(response, &result); err != nil {
|
||||
s.log.Errorw("Failed to parse AI response", "error", err, "response", response[:minInt(500, len(response))])
|
||||
return nil, fmt.Errorf("解析AI响应失败: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infow("Extracted backgrounds from script",
|
||||
"drama_id", dramaID,
|
||||
"backgrounds_count", len(result.Backgrounds))
|
||||
|
||||
return result.Backgrounds, nil
|
||||
}
|
||||
|
||||
// extractBackgroundsWithAI 使用AI智能分析场景并提取唯一背景
|
||||
func (s *ImageGenerationService) extractBackgroundsWithAI(storyboards []models.Storyboard) ([]BackgroundInfo, error) {
|
||||
if len(storyboards) == 0 {
|
||||
return []BackgroundInfo{}, nil
|
||||
}
|
||||
|
||||
// 构建场景列表文本,使用SceneNumber而不是索引
|
||||
var scenesText string
|
||||
for _, storyboard := range storyboards {
|
||||
location := ""
|
||||
if storyboard.Location != nil {
|
||||
location = *storyboard.Location
|
||||
}
|
||||
time := ""
|
||||
if storyboard.Time != nil {
|
||||
time = *storyboard.Time
|
||||
}
|
||||
action := ""
|
||||
if storyboard.Action != nil {
|
||||
action = *storyboard.Action
|
||||
}
|
||||
description := ""
|
||||
if storyboard.Description != nil {
|
||||
description = *storyboard.Description
|
||||
}
|
||||
|
||||
scenesText += fmt.Sprintf("镜头%d:\n地点: %s\n时间: %s\n动作: %s\n描述: %s\n\n",
|
||||
storyboard.StoryboardNumber, location, time, action, description)
|
||||
}
|
||||
|
||||
// 构建AI提示词
|
||||
prompt := fmt.Sprintf(`【任务】分析以下分镜头场景,提取出所有需要生成的唯一背景,并返回每个背景对应的场景编号。
|
||||
|
||||
【分镜头列表】
|
||||
%s
|
||||
|
||||
【要求】
|
||||
1. 合并相同或相似的场景背景(地点和时间相同或相近)
|
||||
2. 为每个唯一背景生成**中文**图片生成提示词(Prompt)
|
||||
3. Prompt要求:
|
||||
- **必须使用中文**,不能包含英文字符
|
||||
- 详细描述场景、时间、氛围、风格
|
||||
- 适合AI图片生成模型使用
|
||||
- 风格统一为:电影感、细节丰富、动漫风格、高质量
|
||||
4. **重要**:必须返回使用该背景的场景编号数组(scene_numbers)
|
||||
5. location、time和prompt字段都使用中文
|
||||
6. 每个场景都必须分配到某个背景,确保所有场景编号都被包含
|
||||
|
||||
【输出JSON格式】
|
||||
{
|
||||
"backgrounds": [
|
||||
{
|
||||
"location": "地点名称(中文)",
|
||||
"time": "时间描述(中文)",
|
||||
"prompt": "一个电影感的动漫风格背景,展现[地点描述]在[时间]的场景。画面呈现[细节描述]。风格:细节丰富,高质量,氛围光照。情绪:[情绪描述]。",
|
||||
"scene_numbers": [1, 2, 3]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
【示例】
|
||||
正确示例:
|
||||
{
|
||||
"backgrounds": [
|
||||
{
|
||||
"location": "维修店",
|
||||
"time": "深夜",
|
||||
"prompt": "一个电影感的动漫风格背景,展现凌乱的维修店内部在深夜的场景。昏暗的灯光下,工作台上散落着各种工具和零件,墙上挂着油污的海报。风格:细节丰富,高质量,昏暗氛围。情绪:孤独、工业感。",
|
||||
"scene_numbers": [1, 5, 6, 10, 15]
|
||||
},
|
||||
{
|
||||
"location": "城市全景",
|
||||
"time": "深夜·酸雨",
|
||||
"prompt": "一个电影感的动漫风格背景,展现沿海城市全景在深夜酸雨中的场景。霓虹灯在雨中模糊,高楼大厦笼罩在灰绿色的雨幕中,街道反射着五颜六色的光。风格:细节丰富,高质量,赛博朋克氛围。情绪:压抑、科幻、末世感。",
|
||||
"scene_numbers": [2, 7]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
请严格按照JSON格式输出,确保:
|
||||
1. prompt字段使用中文
|
||||
2. scene_numbers包含所有使用该背景的场景编号
|
||||
3. 所有场景都被分配到某个背景`, scenesText)
|
||||
|
||||
// 调用AI服务
|
||||
text, err := s.aiService.GenerateText(prompt, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("AI analysis failed: %w", err)
|
||||
}
|
||||
|
||||
// 解析AI返回的JSON
|
||||
var result struct {
|
||||
Scenes []struct {
|
||||
Location string `json:"location"`
|
||||
Time string `json:"time"`
|
||||
Prompt string `json:"prompt"`
|
||||
StoryboardNumber []int `json:"storyboard_number"`
|
||||
} `json:"backgrounds"`
|
||||
}
|
||||
|
||||
if err := utils.SafeParseAIJSON(text, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse AI response: %w", err)
|
||||
}
|
||||
|
||||
// 构建场景编号到场景ID的映射
|
||||
storyboardNumberToID := make(map[int]uint)
|
||||
for _, scene := range storyboards {
|
||||
storyboardNumberToID[scene.StoryboardNumber] = scene.ID
|
||||
}
|
||||
|
||||
// 转换为BackgroundInfo
|
||||
var backgrounds []BackgroundInfo
|
||||
for _, bg := range result.Scenes {
|
||||
// 将场景编号转换为场景ID
|
||||
var sceneIDs []uint
|
||||
for _, storyboardNum := range bg.StoryboardNumber {
|
||||
if storyboardID, ok := storyboardNumberToID[storyboardNum]; ok {
|
||||
sceneIDs = append(sceneIDs, storyboardID)
|
||||
}
|
||||
}
|
||||
|
||||
backgrounds = append(backgrounds, BackgroundInfo{
|
||||
Location: bg.Location,
|
||||
Time: bg.Time,
|
||||
Prompt: bg.Prompt,
|
||||
StoryboardNumbers: bg.StoryboardNumber,
|
||||
SceneIDs: sceneIDs,
|
||||
StoryboardCount: len(sceneIDs),
|
||||
})
|
||||
}
|
||||
|
||||
s.log.Infow("AI extracted backgrounds",
|
||||
"total_scenes", len(storyboards),
|
||||
"extracted_backgrounds", len(backgrounds))
|
||||
|
||||
return backgrounds, nil
|
||||
}
|
||||
|
||||
// extractUniqueBackgrounds 从分镜头中提取唯一背景(代码逻辑,作为AI提取的备份)
|
||||
func (s *ImageGenerationService) extractUniqueBackgrounds(scenes []models.Storyboard) []BackgroundInfo {
|
||||
backgroundMap := make(map[string]*BackgroundInfo)
|
||||
|
||||
for _, scene := range scenes {
|
||||
if scene.Location == nil || scene.Time == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 使用 location + time 作为唯一标识
|
||||
key := *scene.Location + "|" + *scene.Time
|
||||
|
||||
if bg, exists := backgroundMap[key]; exists {
|
||||
// 背景已存在,添加scene ID
|
||||
bg.SceneIDs = append(bg.SceneIDs, scene.ID)
|
||||
bg.StoryboardCount++
|
||||
} else {
|
||||
// 新背景 - 使用ImagePrompt构建背景提示词
|
||||
prompt := ""
|
||||
if scene.ImagePrompt != nil {
|
||||
prompt = *scene.ImagePrompt
|
||||
}
|
||||
backgroundMap[key] = &BackgroundInfo{
|
||||
Location: *scene.Location,
|
||||
Time: *scene.Time,
|
||||
Prompt: prompt,
|
||||
SceneIDs: []uint{scene.ID},
|
||||
StoryboardCount: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 转换为切片
|
||||
var backgrounds []BackgroundInfo
|
||||
for _, bg := range backgroundMap {
|
||||
backgrounds = append(backgrounds, *bg)
|
||||
}
|
||||
|
||||
return backgrounds
|
||||
}
|
||||
21
application/services/resource_transfer_service.go
Normal file
21
application/services/resource_transfer_service.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ResourceTransferService struct {
|
||||
db *gorm.DB
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewResourceTransferService(db *gorm.DB, log *logger.Logger) *ResourceTransferService {
|
||||
return &ResourceTransferService{
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// ResourceTransferService 现在只保留基本结构,MinIO相关功能已移除
|
||||
// 如需资源转存功能,请使用本地存储
|
||||
511
application/services/script_generation_service.go
Normal file
511
application/services/script_generation_service.go
Normal file
@@ -0,0 +1,511 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/pkg/ai"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"github.com/drama-generator/backend/pkg/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type ScriptGenerationService struct {
|
||||
db *gorm.DB
|
||||
aiService *AIService
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewScriptGenerationService(db *gorm.DB, log *logger.Logger) *ScriptGenerationService {
|
||||
return &ScriptGenerationService{
|
||||
db: db,
|
||||
aiService: NewAIService(db, log),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type GenerateOutlineRequest struct {
|
||||
DramaID string `json:"drama_id" binding:"required"`
|
||||
Theme string `json:"theme" binding:"required,min=2,max=500"`
|
||||
Genre string `json:"genre"`
|
||||
Style string `json:"style"`
|
||||
Length int `json:"length"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
}
|
||||
|
||||
type GenerateCharactersRequest struct {
|
||||
DramaID string `json:"drama_id" binding:"required"`
|
||||
Outline string `json:"outline"`
|
||||
Count int `json:"count"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
}
|
||||
|
||||
type GenerateEpisodesRequest struct {
|
||||
DramaID string `json:"drama_id" binding:"required"`
|
||||
Outline string `json:"outline"`
|
||||
EpisodeCount int `json:"episode_count" binding:"required,min=1,max=100"`
|
||||
Temperature float64 `json:"temperature"`
|
||||
}
|
||||
|
||||
type OutlineResult struct {
|
||||
Title string `json:"title"`
|
||||
Summary string `json:"summary"`
|
||||
Genre string `json:"genre"`
|
||||
Tags []string `json:"tags"`
|
||||
Characters []CharacterOutline `json:"characters"`
|
||||
Episodes []EpisodeOutline `json:"episodes"`
|
||||
KeyScenes []string `json:"key_scenes"`
|
||||
}
|
||||
|
||||
type CharacterOutline struct {
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
Description string `json:"description"`
|
||||
Personality string `json:"personality"`
|
||||
Appearance string `json:"appearance"`
|
||||
}
|
||||
|
||||
type EpisodeOutline struct {
|
||||
EpisodeNumber int `json:"episode_number"`
|
||||
Title string `json:"title"`
|
||||
Summary string `json:"summary"`
|
||||
Scenes []string `json:"scenes"`
|
||||
Duration int `json:"duration"`
|
||||
}
|
||||
|
||||
func (s *ScriptGenerationService) GenerateOutline(req *GenerateOutlineRequest) (*OutlineResult, error) {
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ?", req.DramaID).First(&drama).Error; err != nil {
|
||||
return nil, fmt.Errorf("drama not found")
|
||||
}
|
||||
|
||||
systemPrompt := `你是专业短剧编剧。根据主题和剧集数量,创作完整的短剧大纲,规划好每一集的剧情走向。
|
||||
|
||||
要求:
|
||||
1. 剧情紧凑,矛盾冲突强烈,节奏快
|
||||
2. 必须规划好每一集的核心剧情
|
||||
3. 每集有明确冲突和转折点,集与集之间有连贯性和悬念
|
||||
|
||||
**重要:必须输出完整有效的JSON,确保所有字段完整,特别是episodes数组必须完整闭合!**
|
||||
|
||||
JSON格式(紧凑,summary和episodes字段必须完整):
|
||||
{"title":"剧名","summary":"200-250字剧情概述,包含故事背景、主要矛盾、核心冲突、完整走向","genre":"类型","tags":["标签1","标签2","标签3"],"episodes":[{"episode_number":1,"title":"标题","summary":"80字剧情概要"},{"episode_number":2,"title":"标题","summary":"80字剧情概要"}],"key_scenes":["场景1","场景2","场景3"]}
|
||||
|
||||
关键要求:
|
||||
- summary控制在200-250字,简洁清晰
|
||||
- episodes必须生成用户要求的完整集数
|
||||
- 每集summary控制在80字左右
|
||||
- 确保JSON完整闭合,不要截断
|
||||
- 不要添加任何JSON外的文字说明`
|
||||
|
||||
userPrompt := fmt.Sprintf(`请为以下主题创作短剧大纲:
|
||||
|
||||
主题:%s`, req.Theme)
|
||||
|
||||
if req.Genre != "" {
|
||||
userPrompt += fmt.Sprintf("\n类型偏好:%s", req.Genre)
|
||||
}
|
||||
|
||||
if req.Style != "" {
|
||||
userPrompt += fmt.Sprintf("\n风格要求:%s", req.Style)
|
||||
}
|
||||
|
||||
length := req.Length
|
||||
if length == 0 {
|
||||
length = 5
|
||||
}
|
||||
userPrompt += fmt.Sprintf("\n剧集数量:%d集", length)
|
||||
userPrompt += fmt.Sprintf("\n\n**重要:必须在episodes数组中规划完整的%d集剧情,每集都要有明确的故事内容!**", length)
|
||||
|
||||
temperature := req.Temperature
|
||||
if temperature == 0 {
|
||||
temperature = 0.8
|
||||
}
|
||||
|
||||
// 调整token限制:基础2000 + 每集约150 tokens(包含80-100字概要)
|
||||
maxTokens := 2000 + (length * 150)
|
||||
if maxTokens > 8000 {
|
||||
maxTokens = 8000
|
||||
}
|
||||
|
||||
s.log.Infow("Generating outline with episodes",
|
||||
"episode_count", length,
|
||||
"max_tokens", maxTokens)
|
||||
|
||||
text, err := s.aiService.GenerateText(
|
||||
userPrompt,
|
||||
systemPrompt,
|
||||
ai.WithTemperature(temperature),
|
||||
ai.WithMaxTokens(maxTokens),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to generate outline", "error", err)
|
||||
return nil, fmt.Errorf("生成失败: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infow("AI response received", "length", len(text), "preview", text[:minInt(200, len(text))])
|
||||
|
||||
var result OutlineResult
|
||||
if err := utils.SafeParseAIJSON(text, &result); err != nil {
|
||||
s.log.Errorw("Failed to parse outline JSON", "error", err, "raw_response", text[:minInt(500, len(text))])
|
||||
return nil, fmt.Errorf("解析 AI 返回结果失败: %w", err)
|
||||
}
|
||||
|
||||
// 将Tags转换为JSON格式存储
|
||||
tagsJSON, err := json.Marshal(result.Tags)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to marshal tags", "error", err)
|
||||
tagsJSON = []byte("[]")
|
||||
}
|
||||
|
||||
if err := s.db.Model(&drama).Updates(map[string]interface{}{
|
||||
"title": result.Title,
|
||||
"description": result.Summary,
|
||||
"genre": result.Genre,
|
||||
"tags": tagsJSON,
|
||||
}).Error; err != nil {
|
||||
s.log.Errorw("Failed to update drama", "error", err)
|
||||
}
|
||||
|
||||
s.log.Infow("Outline generated", "drama_id", req.DramaID)
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequest) ([]models.Character, error) {
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", req.DramaID).First(&drama).Error; err != nil {
|
||||
return nil, fmt.Errorf("drama not found")
|
||||
}
|
||||
|
||||
count := req.Count
|
||||
if count == 0 {
|
||||
count = 5
|
||||
}
|
||||
|
||||
systemPrompt := `你是一个专业的角色分析师,擅长从剧本中提取和分析角色信息。
|
||||
|
||||
你的任务是根据提供的剧本内容,提取并整理剧中出现的所有角色的详细设定。
|
||||
|
||||
要求:
|
||||
1. 仔细阅读剧本,识别所有出现的角色
|
||||
2. 根据剧本中的对话、行为和描述,总结角色的性格特点
|
||||
3. 提取角色在剧本中的关键信息:背景、动机、目标、关系等
|
||||
4. 角色之间的关系必须基于剧本中的实际描述
|
||||
5. 外貌描述必须极其详细,如果剧本中有描述则使用,如果没有则根据角色设定合理推断,便于AI绘画生成角色形象
|
||||
6. 优先提取主要角色和重要配角,次要角色可以简略
|
||||
|
||||
请严格按照以下 JSON 格式输出,不要添加任何其他文字:
|
||||
|
||||
{
|
||||
"characters": [
|
||||
{
|
||||
"name": "角色名",
|
||||
"role": "主角/重要配角/配角",
|
||||
"description": "角色背景和简介(200-300字,包括:出身背景、成长经历、核心动机、与其他角色的关系、在故事中的作用)",
|
||||
"personality": "性格特点(详细描述,100-150字,包括:主要性格特征、行为习惯、价值观、优点缺点、情绪表达方式、对待他人的态度等)",
|
||||
"appearance": "外貌描述(极其详细,150-200字,必须包括:确切年龄、精确身高、体型身材、肤色质感、发型发色发长、眼睛颜色形状、面部特征(如眉毛、鼻子、嘴唇)、着装风格、服装颜色材质、配饰细节、标志性特征、整体气质风格等,描述要具体到可以直接用于AI绘画)",
|
||||
"voice_style": "说话风格和语气特点(详细描述,50-80字,包括:语速语调、用词习惯、口头禅、说话时的情绪特征等)"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
注意:
|
||||
- 必须基于剧本内容提取角色,不要凭空创作
|
||||
- 优先提取主要角色和重要配角,数量根据剧本实际情况确定
|
||||
- description、personality、appearance、voice_style都必须详细描述,字数要充足
|
||||
- appearance外貌描述是重中之重,必须极其详细具体,要能让AI准确生成角色形象
|
||||
- 如果剧本中角色信息不完整,可以根据角色设定合理补充,但要符合剧本整体风格`
|
||||
|
||||
outlineText := req.Outline
|
||||
if outlineText == "" {
|
||||
outlineText = fmt.Sprintf("剧名:%s\n简介:%s\n类型:%s", drama.Title, drama.Description, drama.Genre)
|
||||
}
|
||||
|
||||
userPrompt := fmt.Sprintf(`剧本内容:
|
||||
%s
|
||||
|
||||
请从剧本中提取并整理最多 %d 个主要角色的详细设定。`, outlineText, count)
|
||||
|
||||
temperature := req.Temperature
|
||||
if temperature == 0 {
|
||||
temperature = 0.7
|
||||
}
|
||||
|
||||
text, err := s.aiService.GenerateText(
|
||||
userPrompt,
|
||||
systemPrompt,
|
||||
ai.WithTemperature(temperature),
|
||||
ai.WithMaxTokens(3000),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to generate characters", "error", err)
|
||||
return nil, fmt.Errorf("生成失败: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infow("AI response received", "length", len(text), "preview", text[:minInt(200, len(text))])
|
||||
|
||||
var result struct {
|
||||
Characters []struct {
|
||||
Name string `json:"name"`
|
||||
Role string `json:"role"`
|
||||
Description string `json:"description"`
|
||||
Personality string `json:"personality"`
|
||||
Appearance string `json:"appearance"`
|
||||
VoiceStyle string `json:"voice_style"`
|
||||
} `json:"characters"`
|
||||
}
|
||||
|
||||
if err := utils.SafeParseAIJSON(text, &result); err != nil {
|
||||
s.log.Errorw("Failed to parse characters JSON", "error", err, "raw_response", text[:minInt(500, len(text))])
|
||||
return nil, fmt.Errorf("解析 AI 返回结果失败: %w", err)
|
||||
}
|
||||
|
||||
var characters []models.Character
|
||||
for _, char := range result.Characters {
|
||||
// 检查角色是否已存在
|
||||
var existingChar models.Character
|
||||
err := s.db.Where("drama_id = ? AND name = ?", req.DramaID, char.Name).First(&existingChar).Error
|
||||
if err == nil {
|
||||
// 角色已存在,直接使用已存在的角色,不覆盖
|
||||
s.log.Infow("Character already exists, skipping", "drama_id", req.DramaID, "name", char.Name)
|
||||
characters = append(characters, existingChar)
|
||||
continue
|
||||
}
|
||||
|
||||
// 角色不存在,创建新角色
|
||||
dramaID, _ := strconv.ParseUint(req.DramaID, 10, 32)
|
||||
character := models.Character{
|
||||
DramaID: uint(dramaID),
|
||||
Name: char.Name,
|
||||
Role: &char.Role,
|
||||
Description: &char.Description,
|
||||
Personality: &char.Personality,
|
||||
Appearance: &char.Appearance,
|
||||
VoiceStyle: &char.VoiceStyle,
|
||||
}
|
||||
|
||||
if err := s.db.Create(&character).Error; err != nil {
|
||||
s.log.Errorw("Failed to create character", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
characters = append(characters, character)
|
||||
}
|
||||
|
||||
s.log.Infow("Characters generated", "drama_id", req.DramaID, "total_count", len(characters), "new_count", len(characters))
|
||||
return characters, nil
|
||||
}
|
||||
|
||||
func (s *ScriptGenerationService) GenerateEpisodes(req *GenerateEpisodesRequest) ([]models.Episode, error) {
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", req.DramaID).First(&drama).Error; err != nil {
|
||||
return nil, fmt.Errorf("drama not found")
|
||||
}
|
||||
|
||||
// 获取角色信息
|
||||
var characters []models.Character
|
||||
s.db.Where("drama_id = ?", req.DramaID).Find(&characters)
|
||||
|
||||
var characterList string
|
||||
if len(characters) > 0 {
|
||||
characterList = "\n角色设定:\n"
|
||||
for _, char := range characters {
|
||||
characterList += fmt.Sprintf("- %s", char.Name)
|
||||
if char.Role != nil {
|
||||
characterList += fmt.Sprintf("(%s)", *char.Role)
|
||||
}
|
||||
if char.Description != nil {
|
||||
characterList += fmt.Sprintf(":%s", *char.Description)
|
||||
}
|
||||
if char.Personality != nil {
|
||||
characterList += fmt.Sprintf(" | 性格:%s", *char.Personality)
|
||||
}
|
||||
characterList += "\n"
|
||||
}
|
||||
} else {
|
||||
characterList = "\n(注意:尚未设定角色,请根据大纲创作合理的角色出场)\n"
|
||||
}
|
||||
|
||||
systemPrompt := `你是一个专业的短剧编剧。你擅长根据分集规划创作详细的剧情内容。
|
||||
|
||||
你的任务是根据大纲中的分集规划,将每一集的概要扩展为详细的剧情叙述。每集约180秒(3分钟),需要充实的内容。
|
||||
|
||||
工作流程:
|
||||
1. 大纲中已提供每集的剧情规划(80-100字概要)
|
||||
2. 你需要将每集概要扩展为400-500字的详细剧情叙述
|
||||
3. 严格按照分集规划的数量和走向展开,不能遗漏任何一集
|
||||
|
||||
详细要求:
|
||||
1. script_content用400-500字详细叙述,包括:
|
||||
- 具体场景和环境描写
|
||||
- 角色的行动、对话要点、情绪变化
|
||||
- 冲突的产生过程和激化细节
|
||||
- 关键情节点和转折
|
||||
- 为下一集埋下的伏笔
|
||||
2. 每集有明确的冲突和转折点
|
||||
3. 集与集之间有连贯性和悬念
|
||||
4. 充分展现角色性格和关系演变
|
||||
5. 内容详实,足以支撑180秒时长
|
||||
|
||||
JSON格式(紧凑):
|
||||
{"episodes":[{"episode_number":1,"title":"标题","description":"简短梗概","script_content":"400-500字详细剧情叙述","duration":210}]}
|
||||
|
||||
格式说明:
|
||||
1. script_content为叙述文,不是场景对话格式
|
||||
2. 每集包含开场铺垫、冲突发展、高潮转折、结局悬念
|
||||
3. duration根据剧情复杂度设置在150-300秒
|
||||
|
||||
关键要求:
|
||||
- 大纲规划了几集就必须生成几集
|
||||
- 严格按照分集规划的故事线展开
|
||||
- 每一集都要有完整的400-500字详细内容
|
||||
- 绝对不能遗漏任何一集`
|
||||
|
||||
outlineText := req.Outline
|
||||
if outlineText == "" {
|
||||
outlineText = fmt.Sprintf("剧名:%s\n简介:%s\n类型:%s", drama.Title, drama.Description, drama.Genre)
|
||||
}
|
||||
|
||||
userPrompt := fmt.Sprintf(`剧本大纲:
|
||||
%s
|
||||
%s
|
||||
请基于以上大纲和角色,创作 %d 集的详细剧本。
|
||||
|
||||
**重要要求:**
|
||||
- 必须生成完整的 %d 集,从第1集到第%d集,不能遗漏
|
||||
- 每集约3-5分钟(150-300秒)
|
||||
- 每集的duration字段要根据剧本内容长度合理设置,不要都设置为同一个值
|
||||
- 返回的JSON中episodes数组必须包含 %d 个元素`, outlineText, characterList, req.EpisodeCount, req.EpisodeCount, req.EpisodeCount, req.EpisodeCount)
|
||||
|
||||
temperature := req.Temperature
|
||||
if temperature == 0 {
|
||||
temperature = 0.7
|
||||
}
|
||||
|
||||
// 根据剧集数量调整token限制
|
||||
// 模型支持128k上下文,每集400-500字约需800-1000 tokens(包含JSON结构)
|
||||
baseTokens := 3000 // 基础(系统提示+角色列表+大纲)
|
||||
perEpisodeTokens := 900 // 每集约900 tokens(支持400-500字详细内容)
|
||||
maxTokens := baseTokens + (req.EpisodeCount * perEpisodeTokens)
|
||||
|
||||
// 128k上下文,可以设置较大的token限制
|
||||
// 10集约12000 tokens,20集约21000 tokens,都在安全范围内
|
||||
if maxTokens > 32000 {
|
||||
maxTokens = 32000 // 保守限制在32k,留足够空间
|
||||
}
|
||||
|
||||
s.log.Infow("Generating episodes with token limit",
|
||||
"episode_count", req.EpisodeCount,
|
||||
"max_tokens", maxTokens,
|
||||
"estimated_per_episode", perEpisodeTokens)
|
||||
|
||||
text, err := s.aiService.GenerateText(
|
||||
userPrompt,
|
||||
systemPrompt,
|
||||
ai.WithTemperature(0.8),
|
||||
ai.WithMaxTokens(maxTokens),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to generate episodes", "error", err)
|
||||
return nil, fmt.Errorf("生成失败: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infow("AI response received", "length", len(text), "preview", text[:minInt(200, len(text))])
|
||||
|
||||
var result struct {
|
||||
Episodes []struct {
|
||||
EpisodeNumber int `json:"episode_number"`
|
||||
Title string `json:"title"`
|
||||
Description string `json:"description"`
|
||||
ScriptContent string `json:"script_content"`
|
||||
Duration int `json:"duration"`
|
||||
} `json:"episodes"`
|
||||
}
|
||||
|
||||
if err := utils.SafeParseAIJSON(text, &result); err != nil {
|
||||
s.log.Errorw("Failed to parse episodes JSON", "error", err, "raw_response", text[:minInt(500, len(text))])
|
||||
return nil, fmt.Errorf("解析 AI 返回结果失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查生成的集数是否符合要求
|
||||
if len(result.Episodes) < req.EpisodeCount {
|
||||
s.log.Warnw("AI generated fewer episodes than requested",
|
||||
"requested", req.EpisodeCount,
|
||||
"generated", len(result.Episodes))
|
||||
}
|
||||
|
||||
// 记录每集的详细信息
|
||||
for i, ep := range result.Episodes {
|
||||
s.log.Infow("Episode parsed from AI",
|
||||
"index", i,
|
||||
"episode_number", ep.EpisodeNumber,
|
||||
"title", ep.Title,
|
||||
"description_length", len(ep.Description),
|
||||
"script_content_length", len(ep.ScriptContent),
|
||||
"duration", ep.Duration)
|
||||
}
|
||||
|
||||
var episodes []models.Episode
|
||||
for _, ep := range result.Episodes {
|
||||
duration := ep.Duration
|
||||
if duration == 0 {
|
||||
// AI未返回时长时使用默认值
|
||||
duration = 180
|
||||
s.log.Warnw("Episode duration not provided by AI, using default",
|
||||
"episode_number", ep.EpisodeNumber,
|
||||
"default_duration", 180)
|
||||
} else {
|
||||
s.log.Infow("Episode duration from AI",
|
||||
"episode_number", ep.EpisodeNumber,
|
||||
"duration", duration)
|
||||
}
|
||||
|
||||
// 记录即将保存的数据
|
||||
s.log.Infow("Creating episode in database",
|
||||
"episode_number", ep.EpisodeNumber,
|
||||
"title", ep.Title,
|
||||
"script_content_length", len(ep.ScriptContent),
|
||||
"script_content_empty", ep.ScriptContent == "")
|
||||
|
||||
dramaID, err := strconv.ParseUint(req.DramaID, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid drama ID")
|
||||
}
|
||||
|
||||
episode := models.Episode{
|
||||
DramaID: uint(dramaID),
|
||||
EpisodeNum: ep.EpisodeNumber,
|
||||
Title: ep.Title,
|
||||
Description: &ep.Description,
|
||||
ScriptContent: &ep.ScriptContent,
|
||||
Duration: duration,
|
||||
Status: "draft",
|
||||
}
|
||||
|
||||
if err := s.db.Create(&episode).Error; err != nil {
|
||||
s.log.Errorw("Failed to create episode", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
episodes = append(episodes, episode)
|
||||
}
|
||||
|
||||
s.log.Infow("Episodes generated", "drama_id", req.DramaID, "count", len(episodes))
|
||||
return episodes, nil
|
||||
}
|
||||
|
||||
// GenerateScenesForEpisode 已废弃,使用 StoryboardService.GenerateStoryboard 替代
|
||||
// ParseScript 已废弃,使用 GenerateCharacters 替代
|
||||
|
||||
// minInt 返回两个整数中较小的一个
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
395
application/services/storyboard_composition_service.go
Normal file
395
application/services/storyboard_composition_service.go
Normal file
@@ -0,0 +1,395 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
models "github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type StoryboardCompositionService struct {
|
||||
db *gorm.DB
|
||||
log *logger.Logger
|
||||
imageGen *ImageGenerationService
|
||||
}
|
||||
|
||||
func NewStoryboardCompositionService(db *gorm.DB, log *logger.Logger, imageGen *ImageGenerationService) *StoryboardCompositionService {
|
||||
return &StoryboardCompositionService{
|
||||
db: db,
|
||||
log: log,
|
||||
imageGen: imageGen,
|
||||
}
|
||||
}
|
||||
|
||||
type SceneCharacterInfo struct {
|
||||
ID uint `json:"id"`
|
||||
Name string `json:"name"`
|
||||
ImageURL *string `json:"image_url,omitempty"`
|
||||
}
|
||||
|
||||
type SceneBackgroundInfo struct {
|
||||
ID uint `json:"id"`
|
||||
Location string `json:"location"`
|
||||
Time string `json:"time"`
|
||||
ImageURL *string `json:"image_url,omitempty"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
type SceneCompositionInfo struct {
|
||||
ID uint `json:"id"`
|
||||
StoryboardNumber int `json:"storyboard_number"`
|
||||
Title *string `json:"title"`
|
||||
Description *string `json:"description"`
|
||||
Location *string `json:"location"`
|
||||
Time *string `json:"time"`
|
||||
Duration int `json:"duration"`
|
||||
Dialogue *string `json:"dialogue"`
|
||||
Action *string `json:"action"`
|
||||
Atmosphere *string `json:"atmosphere"`
|
||||
ImagePrompt *string `json:"image_prompt,omitempty"`
|
||||
VideoPrompt *string `json:"video_prompt,omitempty"`
|
||||
Characters []SceneCharacterInfo `json:"characters"`
|
||||
Background *SceneBackgroundInfo `json:"background"`
|
||||
SceneID *uint `json:"scene_id"`
|
||||
ComposedImage *string `json:"composed_image,omitempty"`
|
||||
VideoURL *string `json:"video_url,omitempty"`
|
||||
ImageGenerationID *uint `json:"image_generation_id,omitempty"`
|
||||
ImageGenerationStatus *string `json:"image_generation_status,omitempty"`
|
||||
VideoGenerationID *uint `json:"video_generation_id,omitempty"`
|
||||
VideoGenerationStatus *string `json:"video_generation_status,omitempty"`
|
||||
}
|
||||
|
||||
func (s *StoryboardCompositionService) GetScenesForEpisode(episodeID string) ([]SceneCompositionInfo, error) {
|
||||
// 验证权限
|
||||
var episode models.Episode
|
||||
err := s.db.Preload("Drama").Where("id = ?", episodeID).First(&episode).Error
|
||||
if err != nil {
|
||||
s.log.Errorw("Episode not found", "episode_id", episodeID, "error", err)
|
||||
return nil, fmt.Errorf("episode not found")
|
||||
}
|
||||
|
||||
s.log.Infow("GetScenesForEpisode auth check",
|
||||
"episode_id", episodeID,
|
||||
"drama_id", episode.DramaID)
|
||||
|
||||
// 获取分镜列表
|
||||
var storyboards []models.Storyboard
|
||||
if err := s.db.Where("episode_id = ?", episodeID).
|
||||
Preload("Characters").
|
||||
Order("storyboard_number ASC").
|
||||
Find(&storyboards).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to load storyboards: %w", err)
|
||||
}
|
||||
|
||||
// 获取所有角色(用于匹配角色信息)
|
||||
var characters []models.Character
|
||||
if err := s.db.Where("drama_id = ?", episode.DramaID).Find(&characters).Error; err != nil {
|
||||
s.log.Warnw("Failed to load characters", "error", err)
|
||||
}
|
||||
|
||||
// 创建角色ID到角色信息的映射
|
||||
charIDToInfo := make(map[uint]*models.Character)
|
||||
for i := range characters {
|
||||
charIDToInfo[characters[i].ID] = &characters[i]
|
||||
}
|
||||
|
||||
// 获取所有场景ID
|
||||
var sceneIDs []uint
|
||||
for _, storyboard := range storyboards {
|
||||
if storyboard.SceneID != nil {
|
||||
sceneIDs = append(sceneIDs, *storyboard.SceneID)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量获取场景信息
|
||||
var scenes []models.Scene
|
||||
sceneMap := make(map[uint]*models.Scene)
|
||||
if len(sceneIDs) > 0 {
|
||||
if err := s.db.Where("id IN ?", sceneIDs).Find(&scenes).Error; err == nil {
|
||||
for i := range scenes {
|
||||
sceneMap[scenes[i].ID] = &scenes[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 获取分镜的合成图片(从 image_generations 表)
|
||||
storyboardIDs := make([]uint, len(storyboards))
|
||||
for i, storyboard := range storyboards {
|
||||
storyboardIDs[i] = storyboard.ID
|
||||
}
|
||||
|
||||
imageGenMap := make(map[uint]string) // storyboard_id -> image_url
|
||||
imageGenTaskMap := make(map[uint]*models.ImageGeneration) // storyboard_id -> processing task
|
||||
if len(storyboardIDs) > 0 {
|
||||
var imageGens []models.ImageGeneration
|
||||
// 查询已完成的图片生成记录,每个镜头只取最新的一条
|
||||
if err := s.db.Where("storyboard_id IN ? AND status = ?", storyboardIDs, models.ImageStatusCompleted).
|
||||
Order("created_at DESC").
|
||||
Find(&imageGens).Error; err == nil {
|
||||
// 为每个镜头保留最新的一条记录
|
||||
for _, ig := range imageGens {
|
||||
if ig.StoryboardID != nil {
|
||||
if _, exists := imageGenMap[*ig.StoryboardID]; !exists {
|
||||
if ig.ImageURL != nil {
|
||||
imageGenMap[*ig.StoryboardID] = *ig.ImageURL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查询进行中的图片生成任务
|
||||
var processingImageGens []models.ImageGeneration
|
||||
if err := s.db.Where("storyboard_id IN ? AND status = ?", storyboardIDs, models.ImageStatusProcessing).
|
||||
Order("created_at DESC").
|
||||
Find(&processingImageGens).Error; err == nil {
|
||||
for _, ig := range processingImageGens {
|
||||
if ig.StoryboardID != nil {
|
||||
if _, exists := imageGenTaskMap[*ig.StoryboardID]; !exists {
|
||||
igCopy := ig
|
||||
imageGenTaskMap[*ig.StoryboardID] = &igCopy
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 批量查询进行中的视频生成任务
|
||||
videoGenTaskMap := make(map[uint]*models.VideoGeneration) // storyboard_id -> processing task
|
||||
if len(storyboardIDs) > 0 {
|
||||
var processingVideoGens []models.VideoGeneration
|
||||
if err := s.db.Where("scene_id IN ? AND status = ?", storyboardIDs, models.VideoStatusProcessing).
|
||||
Order("created_at DESC").
|
||||
Find(&processingVideoGens).Error; err == nil {
|
||||
for _, vg := range processingVideoGens {
|
||||
if vg.StoryboardID != nil {
|
||||
if _, exists := videoGenTaskMap[*vg.StoryboardID]; !exists {
|
||||
vgCopy := vg
|
||||
videoGenTaskMap[*vg.StoryboardID] = &vgCopy
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构建返回结果
|
||||
var result []SceneCompositionInfo
|
||||
for _, storyboard := range storyboards {
|
||||
storyboardInfo := SceneCompositionInfo{
|
||||
ID: storyboard.ID,
|
||||
StoryboardNumber: storyboard.StoryboardNumber,
|
||||
Title: storyboard.Title,
|
||||
Description: storyboard.Description,
|
||||
Location: storyboard.Location,
|
||||
Time: storyboard.Time,
|
||||
Duration: storyboard.Duration,
|
||||
Action: storyboard.Action,
|
||||
Dialogue: storyboard.Dialogue,
|
||||
Atmosphere: storyboard.Atmosphere,
|
||||
ImagePrompt: storyboard.ImagePrompt,
|
||||
VideoPrompt: storyboard.VideoPrompt,
|
||||
SceneID: storyboard.SceneID,
|
||||
}
|
||||
|
||||
// 直接使用关联的角色信息
|
||||
if len(storyboard.Characters) > 0 {
|
||||
for _, char := range storyboard.Characters {
|
||||
storyboardChar := SceneCharacterInfo{
|
||||
ID: char.ID,
|
||||
Name: char.Name,
|
||||
ImageURL: char.ImageURL,
|
||||
}
|
||||
storyboardInfo.Characters = append(storyboardInfo.Characters, storyboardChar)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加场景信息
|
||||
if storyboard.SceneID != nil {
|
||||
if scene, ok := sceneMap[*storyboard.SceneID]; ok {
|
||||
storyboardInfo.Background = &SceneBackgroundInfo{
|
||||
ID: scene.ID,
|
||||
Location: scene.Location,
|
||||
Time: scene.Time,
|
||||
ImageURL: scene.ImageURL,
|
||||
Status: scene.Status,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加合成图片
|
||||
if imageURL, ok := imageGenMap[storyboard.ID]; ok {
|
||||
storyboardInfo.ComposedImage = &imageURL
|
||||
}
|
||||
|
||||
// 添加视频URL
|
||||
if storyboard.VideoURL != nil {
|
||||
storyboardInfo.VideoURL = storyboard.VideoURL
|
||||
}
|
||||
|
||||
// 添加进行中的图片生成任务信息
|
||||
if imageTask, ok := imageGenTaskMap[storyboard.ID]; ok {
|
||||
storyboardInfo.ImageGenerationID = &imageTask.ID
|
||||
statusStr := string(imageTask.Status)
|
||||
storyboardInfo.ImageGenerationStatus = &statusStr
|
||||
}
|
||||
|
||||
// 添加进行中的视频生成任务信息
|
||||
if videoTask, ok := videoGenTaskMap[storyboard.ID]; ok {
|
||||
storyboardInfo.VideoGenerationID = &videoTask.ID
|
||||
statusStr := string(videoTask.Status)
|
||||
storyboardInfo.VideoGenerationStatus = &statusStr
|
||||
}
|
||||
|
||||
result = append(result, storyboardInfo)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type UpdateSceneRequest struct {
|
||||
SceneID *uint `json:"scene_id"`
|
||||
Characters []uint `json:"characters"` // 改为存储角色ID数组
|
||||
Location *string `json:"location"`
|
||||
Time *string `json:"time"`
|
||||
Action *string `json:"action"`
|
||||
Dialogue *string `json:"dialogue"`
|
||||
Description *string `json:"description"`
|
||||
Duration *int `json:"duration"`
|
||||
ImagePrompt *string `json:"image_prompt"`
|
||||
VideoPrompt *string `json:"video_prompt"`
|
||||
}
|
||||
|
||||
func (s *StoryboardCompositionService) UpdateScene(sceneID string, req *UpdateSceneRequest) error {
|
||||
// 获取分镜并验证权限
|
||||
var storyboard models.Storyboard
|
||||
err := s.db.Preload("Episode.Drama").Where("id = ?", sceneID).First(&storyboard).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("scene not found")
|
||||
}
|
||||
|
||||
// 构建更新数据
|
||||
updates := make(map[string]interface{})
|
||||
|
||||
// 更新背景ID
|
||||
if req.SceneID != nil {
|
||||
updates["scene_id"] = req.SceneID
|
||||
}
|
||||
|
||||
// 更新角色列表(直接存储ID数组)
|
||||
if req.Characters != nil {
|
||||
charactersJSON, err := json.Marshal(req.Characters)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to serialize characters: %w", err)
|
||||
}
|
||||
updates["characters"] = charactersJSON
|
||||
}
|
||||
|
||||
// 更新场景信息字段
|
||||
if req.Location != nil {
|
||||
updates["location"] = req.Location
|
||||
}
|
||||
if req.Time != nil {
|
||||
updates["time"] = req.Time
|
||||
}
|
||||
if req.Action != nil {
|
||||
updates["action"] = req.Action
|
||||
}
|
||||
if req.Dialogue != nil {
|
||||
updates["dialogue"] = req.Dialogue
|
||||
}
|
||||
if req.Description != nil {
|
||||
updates["description"] = req.Description
|
||||
}
|
||||
if req.Duration != nil {
|
||||
updates["duration"] = *req.Duration
|
||||
}
|
||||
if req.ImagePrompt != nil {
|
||||
updates["image_prompt"] = req.ImagePrompt
|
||||
}
|
||||
if req.VideoPrompt != nil {
|
||||
updates["video_prompt"] = req.VideoPrompt
|
||||
}
|
||||
|
||||
// 执行更新
|
||||
if len(updates) > 0 {
|
||||
if err := s.db.Model(&models.Storyboard{}).Where("id = ?", sceneID).Updates(updates).Error; err != nil {
|
||||
return fmt.Errorf("failed to update scene: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Infow("Scene updated", "scene_id", sceneID, "updates", updates)
|
||||
return nil
|
||||
}
|
||||
|
||||
type GenerateSceneImageRequest struct {
|
||||
SceneID uint `json:"scene_id"`
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
func (s *StoryboardCompositionService) GenerateSceneImage(req *GenerateSceneImageRequest) (*models.ImageGeneration, error) {
|
||||
// 获取场景并验证权限
|
||||
var scene models.Scene
|
||||
err := s.db.Where("id = ?", req.SceneID).First(&scene).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("scene not found")
|
||||
}
|
||||
|
||||
// 验证权限:通过DramaID查询Drama
|
||||
var drama models.Drama
|
||||
if err := s.db.Where("id = ? ", scene.DramaID).First(&drama).Error; err != nil {
|
||||
return nil, fmt.Errorf("unauthorized")
|
||||
}
|
||||
|
||||
// 构建场景图片生成提示词
|
||||
prompt := req.Prompt
|
||||
if prompt == "" {
|
||||
// 使用场景的Prompt字段
|
||||
prompt = scene.Prompt
|
||||
if prompt == "" {
|
||||
// 如果Prompt为空,使用Location和Time构建
|
||||
prompt = fmt.Sprintf("%s场景,%s", scene.Location, scene.Time)
|
||||
}
|
||||
s.log.Infow("Using scene prompt", "scene_id", req.SceneID, "prompt", prompt)
|
||||
}
|
||||
|
||||
// 使用imageGen服务直接生成
|
||||
if s.imageGen != nil {
|
||||
genReq := &GenerateImageRequest{
|
||||
SceneID: &req.SceneID,
|
||||
DramaID: fmt.Sprintf("%d", scene.DramaID),
|
||||
ImageType: string(models.ImageTypeScene),
|
||||
Prompt: prompt,
|
||||
Model: req.Model, // 使用用户指定的模型
|
||||
Size: "2560x1440", // 3,686,400像素,满足doubao模型最低要求(16:9比例)
|
||||
Quality: "standard",
|
||||
}
|
||||
imageGen, err := s.imageGen.GenerateImage(genReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate image: %w", err)
|
||||
}
|
||||
|
||||
// 更新场景的image_url
|
||||
if imageGen.ImageURL != nil {
|
||||
scene.ImageURL = imageGen.ImageURL
|
||||
scene.Status = "generated"
|
||||
if err := s.db.Save(&scene).Error; err != nil {
|
||||
s.log.Errorw("Failed to update scene image url", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Infow("Scene image generation created", "scene_id", req.SceneID, "image_gen_id", imageGen.ID)
|
||||
return imageGen, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("image generation service not available")
|
||||
}
|
||||
|
||||
func getStringValue(s *string) string {
|
||||
if s != nil {
|
||||
return *s
|
||||
}
|
||||
return ""
|
||||
}
|
||||
741
application/services/storyboard_service.go
Normal file
741
application/services/storyboard_service.go
Normal file
@@ -0,0 +1,741 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
models "github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"github.com/drama-generator/backend/pkg/utils"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type StoryboardService struct {
|
||||
db *gorm.DB
|
||||
aiService *AIService
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewStoryboardService(db *gorm.DB, log *logger.Logger) *StoryboardService {
|
||||
return &StoryboardService{
|
||||
db: db,
|
||||
aiService: NewAIService(db, log),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type Storyboard struct {
|
||||
ShotNumber int `json:"shot_number"`
|
||||
Title string `json:"title"` // 镜头标题
|
||||
ShotType string `json:"shot_type"` // 景别
|
||||
Angle string `json:"angle"` // 镜头角度
|
||||
Time string `json:"time"` // 时间
|
||||
Location string `json:"location"` // 地点
|
||||
SceneID *uint `json:"scene_id"` // 背景ID(AI直接返回,可为null)
|
||||
Movement string `json:"movement"` // 运镜
|
||||
Action string `json:"action"` // 动作
|
||||
Dialogue string `json:"dialogue"` // 对话/独白
|
||||
Result string `json:"result"` // 画面结果
|
||||
Atmosphere string `json:"atmosphere"` // 环境氛围
|
||||
Emotion string `json:"emotion"` // 情绪
|
||||
Duration int `json:"duration"` // 时长(秒)
|
||||
BgmPrompt string `json:"bgm_prompt"` // 配乐提示词
|
||||
SoundEffect string `json:"sound_effect"` // 音效描述
|
||||
Characters []uint `json:"characters"` // 涉及的角色ID列表
|
||||
IsPrimary bool `json:"is_primary"` // 是否主镜
|
||||
}
|
||||
|
||||
type GenerateStoryboardResult struct {
|
||||
Storyboards []Storyboard `json:"storyboards"`
|
||||
Total int `json:"total"`
|
||||
}
|
||||
|
||||
func (s *StoryboardService) GenerateStoryboard(episodeID string) (*GenerateStoryboardResult, error) {
|
||||
// 从数据库获取剧集信息
|
||||
var episode struct {
|
||||
ID string
|
||||
ScriptContent *string
|
||||
Description *string
|
||||
DramaID string
|
||||
}
|
||||
|
||||
err := s.db.Table("episodes").
|
||||
Select("episodes.id, episodes.script_content, episodes.description, episodes.drama_id").
|
||||
Joins("INNER JOIN dramas ON dramas.id = episodes.drama_id").
|
||||
Where("episodes.id = ?", episodeID).
|
||||
First(&episode).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("剧集不存在或无权限访问")
|
||||
}
|
||||
|
||||
// 获取剧本内容
|
||||
var scriptContent string
|
||||
if episode.ScriptContent != nil && *episode.ScriptContent != "" {
|
||||
scriptContent = *episode.ScriptContent
|
||||
} else if episode.Description != nil && *episode.Description != "" {
|
||||
scriptContent = *episode.Description
|
||||
} else {
|
||||
return nil, fmt.Errorf("剧本内容为空,请先生成剧集内容")
|
||||
}
|
||||
|
||||
// 获取该剧本的所有角色
|
||||
var characters []models.Character
|
||||
if err := s.db.Where("drama_id = ?", episode.DramaID).Order("name ASC").Find(&characters).Error; err != nil {
|
||||
return nil, fmt.Errorf("获取角色列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建角色列表字符串(包含ID和名称)
|
||||
characterList := "无角色"
|
||||
if len(characters) > 0 {
|
||||
var charInfoList []string
|
||||
for _, char := range characters {
|
||||
charInfoList = append(charInfoList, fmt.Sprintf(`{"id": %d, "name": "%s"}`, char.ID, char.Name))
|
||||
}
|
||||
characterList = fmt.Sprintf("[%s]", strings.Join(charInfoList, ", "))
|
||||
}
|
||||
|
||||
// 获取该项目已提取的场景列表(项目级)
|
||||
var scenes []models.Scene
|
||||
if err := s.db.Where("drama_id = ?", episode.DramaID).Order("location ASC, time ASC").Find(&scenes).Error; err != nil {
|
||||
s.log.Warnw("Failed to get scenes", "error", err)
|
||||
}
|
||||
|
||||
// 构建场景列表字符串(包含ID、地点、时间)
|
||||
sceneList := "无场景"
|
||||
if len(scenes) > 0 {
|
||||
var sceneInfoList []string
|
||||
for _, bg := range scenes {
|
||||
sceneInfoList = append(sceneInfoList, fmt.Sprintf(`{"id": %d, "location": "%s", "time": "%s"}`, bg.ID, bg.Location, bg.Time))
|
||||
}
|
||||
sceneList = fmt.Sprintf("[%s]", strings.Join(sceneInfoList, ", "))
|
||||
}
|
||||
|
||||
s.log.Infow("Generating storyboard",
|
||||
"episode_id", episodeID,
|
||||
"drama_id", episode.DramaID,
|
||||
"script_length", len(scriptContent),
|
||||
"character_count", len(characters),
|
||||
"characters", characterList,
|
||||
"scene_count", len(scenes),
|
||||
"scenes", sceneList)
|
||||
|
||||
// 构建分镜头生成提示词
|
||||
prompt := fmt.Sprintf(`【角色】你是一位资深影视分镜师,精通罗伯特·麦基的镜头拆解理论,擅长构建情绪节奏。
|
||||
|
||||
【任务】将小说剧本按**独立动作单元**拆解为分镜头方案。
|
||||
|
||||
【本剧可用角色列表】
|
||||
%s
|
||||
|
||||
**重要**:在characters字段中,只能使用上述角色列表中的角色ID(数字),不得自创角色或使用其他ID。
|
||||
|
||||
【本剧已提取的场景背景列表】
|
||||
%s
|
||||
|
||||
**重要**:在scene_id字段中,必须从上述背景列表中选择最匹配的背景ID(数字)。如果没有合适的背景,则填null。
|
||||
|
||||
【剧本原文】
|
||||
%s
|
||||
|
||||
【分镜要素】每个镜头聚焦单一动作,描述要详尽具体:
|
||||
1. **镜头标题(title)**:用3-5个字概括该镜头的核心内容或情绪
|
||||
- 例如:"噩梦惊醒"、"对视沉思"、"逃离现场"、"意外发现"
|
||||
2. **时间**:[清晨/午后/深夜/具体时分+详细光线描述]
|
||||
- 例如:"深夜22:30·月光从破窗斜射入室内,形成明暗分界"
|
||||
3. **地点**:[场景完整描述+空间布局+环境细节]
|
||||
- 例如:"废弃码头仓库·锈蚀货架林立,地面积水反射微弱灯光,墙角堆放腐朽木箱"
|
||||
4. **镜头设计**:
|
||||
- **景别(shot_type)**:[远景/全景/中景/近景/特写]
|
||||
- **镜头角度(angle)**:[平视/仰视/俯视/侧面/背面]
|
||||
- **运镜方式(movement)**:[固定镜头/推镜/拉镜/摇镜/跟镜/移镜]
|
||||
5. **人物行为**:**详细动作描述**,包含[谁+具体怎么做+肢体细节+表情状态]
|
||||
- 例如:"陈峥弯腰用撬棍撬动保险箱门,手臂青筋暴起,眉头紧锁,汗水滑落脸颊"
|
||||
6. **对话/独白**:提取该镜头中的完整对话或独白内容(如无对话则为空字符串)
|
||||
7. **画面结果**:动作的即时后果+视觉细节+氛围变化
|
||||
- 例如:"保险箱门弹开发出金属碰撞声,扬起灰尘在光束中飘散,箱内空无一物只有陈旧报纸,陈峥表情从期待转为失望"
|
||||
8. **环境氛围**:光线质感+色调+声音环境+整体氛围
|
||||
- 例如:"昏暗冷色调,只有手电筒光束晃动,远处传来海浪拍打声,压抑沉闷"
|
||||
9. **配乐提示(bgm_prompt)**:描述该镜头配乐的氛围、节奏、情绪(如无特殊要求则为空字符串)
|
||||
- 例如:"低沉紧张的弦乐,节奏缓慢,营造压抑氛围"
|
||||
10. **音效描述(sound_effect)**:描述该镜头的关键音效(如无特殊音效则为空字符串)
|
||||
- 例如:"金属碰撞声、脚步声、海浪拍打声"
|
||||
11. **观众情绪**:[情绪类型]([强度:↑↑↑/↑↑/↑/→/↓] + [落点:悬置/释放/反转])
|
||||
|
||||
【输出格式】请以JSON格式输出,每个镜头包含以下字段(**所有描述性字段都要详细完整**):
|
||||
{
|
||||
"storyboards": [
|
||||
{
|
||||
"shot_number": 1,
|
||||
"title": "噩梦惊醒",
|
||||
"shot_type": "全景",
|
||||
"angle": "俯视45度角",
|
||||
"time": "深夜22:30·月光从破窗斜射入仓库,在地面积水中形成银白色反光,墙角昏暗不清",
|
||||
"location": "废弃码头仓库·锈蚀货架林立,地面积水反射微弱灯光,墙角堆放腐朽木箱和渔网,空气中弥漫潮湿霉味",
|
||||
"scene_id": 1,
|
||||
"movement": "固定镜头",
|
||||
"action": "陈峥弯腰双手握住撬棍用力撬动保险箱门,手臂青筋暴起,眉头紧锁,汗水从额头滑落脸颊,呼吸急促",
|
||||
"dialogue": "(独白)这么多年了,里面到底藏着什么秘密?",
|
||||
"result": "保险箱门突然弹开发出刺耳金属声,扬起灰尘在手电筒光束中飘散,箱内空无一物只有几张发黄的旧报纸,陈峥表情从期待转为震惊和失望,瞳孔放大",
|
||||
"atmosphere": "昏暗冷色调·青灰色为主,只有手电筒光束在黑暗中晃动,远处传来海浪拍打码头的沉闷声,整体氛围压抑沉重",
|
||||
"emotion": "好奇感↑↑转失望↓(情绪反转)",
|
||||
"duration": 9,
|
||||
"bgm_prompt": "低沉紧张的弦乐,节奏缓慢,营造压抑悬疑氛围",
|
||||
"sound_effect": "金属碰撞声、灰尘飘散声、海浪拍打声",
|
||||
"characters": [159],
|
||||
"is_primary": true
|
||||
},
|
||||
{
|
||||
"shot_number": 2,
|
||||
"title": "对视沉思",
|
||||
"shot_type": "近景",
|
||||
"angle": "平视",
|
||||
"time": "深夜22:31·仓库内光线昏暗,只有手电筒光从侧面照亮两人脸部轮廓",
|
||||
"location": "废弃码头仓库·保险箱旁,背景是模糊的货架剪影",
|
||||
"scene_id": 1,
|
||||
"movement": "推镜",
|
||||
"action": "陈峥缓缓转身,目光与身后的李芳对视,李芳手握手电筒,光束在两人之间晃动,眼神中透露疑惑和警惕",
|
||||
"dialogue": "陈峥:\"我们被耍了,这里根本没有我们要找的东西。\" 李芳:\"现在怎么办?我们的时间不多了。\"",
|
||||
"result": "两人站在昏暗中陷入沉思,手电筒光束照在地面形成圆形光斑,背景传来微弱的金属摩擦声,气氛紧张凝重",
|
||||
"atmosphere": "低调光线·暗部占画面70%,侧面硬光勾勒人物轮廓,冷暖光对比强烈,海风吹过产生呼啸声,营造紧迫感",
|
||||
"emotion": "紧张感↑↑·警惕↑↑(悬置)",
|
||||
"duration": 7,
|
||||
"bgm_prompt": "紧张感逐渐升级的音效,低频持续音",
|
||||
"sound_effect": "呼吸声、金属摩擦声、海风呼啸声",
|
||||
"characters": [159, 160],
|
||||
"is_primary": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
**dialogue字段说明**:
|
||||
- 如果有对话,格式为:角色名:\"台词内容\"
|
||||
- 多人对话用空格分隔:角色A:\"...\" 角色B:\"...\"
|
||||
- 独白格式为:(独白)内容
|
||||
- 旁白格式为:(旁白)内容
|
||||
- 无对话时填写空字符串:""
|
||||
- **对话内容必须从原剧本中提取,保持原汁原味**
|
||||
|
||||
**角色和背景要求**:
|
||||
- characters字段必须包含该镜头中出现的所有角色ID(数字数组格式)
|
||||
- 只提取实际出现的角色ID,不出现角色则为空数组[]
|
||||
- **角色ID必须严格使用【本剧可用角色列表】中的id字段(数字),不得使用其他ID或自创角色**
|
||||
- 例如:如果镜头中出现李明(id:159)和王芳(id:160),则characters字段应为[159, 160]
|
||||
- scene_id字段必须从【本剧已提取的场景背景列表】中选择最匹配的背景ID(数字)
|
||||
- 如果列表中没有合适的背景,则scene_id填null
|
||||
- 例如:如果镜头发生在"城市公寓卧室·凌晨",应选择id为1的场景背景
|
||||
|
||||
**duration时长估算规则(秒)**:
|
||||
- **所有镜头时长必须在4-12秒范围内**,确保节奏合理流畅
|
||||
- **综合估算原则**:时长由对话内容、动作复杂度、情绪节奏三方面综合决定
|
||||
|
||||
**估算步骤**:
|
||||
1. **基础时长**(从场景内容判断):
|
||||
- 纯对话场景(无明显动作):基础4秒
|
||||
- 纯动作场景(无对话):基础5秒
|
||||
- 对话+动作混合场景:基础6秒
|
||||
|
||||
2. **对话调整**(根据台词字数增加时长):
|
||||
- 无对话:+0秒
|
||||
- 短对话(1-20字):+1-2秒
|
||||
- 中等对话(21-50字):+2-4秒
|
||||
- 长对话(51字以上):+4-6秒
|
||||
|
||||
3. **动作调整**(根据动作复杂度增加时长):
|
||||
- 无动作/静态:+0秒
|
||||
- 简单动作(表情、转身、拿物品):+0-1秒
|
||||
- 一般动作(走动、开门、坐下):+1-2秒
|
||||
- 复杂动作(打斗、追逐、大幅度移动):+2-4秒
|
||||
- 环境展示(全景扫描、氛围营造):+2-5秒
|
||||
|
||||
4. **最终时长** = 基础时长 + 对话调整 + 动作调整,确保结果在4-12秒范围内
|
||||
|
||||
**示例**:
|
||||
- "陈峥转身离开"(简单动作,无对话):5 + 0 + 1 = 6秒
|
||||
- "李芳:\"你要去哪里?\""(短对话,无动作):4 + 2 + 0 = 6秒
|
||||
- "陈峥推开房门,李芳:\"终于找到你了,这些年你去哪了?\""(一般动作+中等对话):6 + 3 + 2 = 11秒
|
||||
- "两人在雨中激烈搏斗,陈峥:\"住手!\""(复杂动作+短对话):6 + 2 + 4 = 12秒
|
||||
|
||||
**重要**:准确估算每个镜头时长,所有分镜时长之和将作为剧集总时长
|
||||
|
||||
**特别要求**:
|
||||
- **【极其重要】必须100%%完整拆解整个剧本,不得省略、跳过、压缩任何剧情内容**
|
||||
- **从剧本第一个字到最后一个字,逐句逐段转换为分镜**
|
||||
- **每个对话、每个动作、每个场景转换都必须有对应的分镜**
|
||||
- 剧本越长,分镜数量越多(短剧本15-30个,中等剧本30-60个,长剧本60-100个甚至更多)
|
||||
- **宁可分镜多,也不要遗漏剧情**:一个长场景可拆分为多个连续分镜
|
||||
- 每个镜头只描述一个主要动作
|
||||
- 区分主镜(is_primary: true)和链接镜(is_primary: false)
|
||||
- 确保情绪节奏有变化
|
||||
- **duration字段至关重要**:准确估算每个镜头时长,这将用于计算整集时长
|
||||
- 严格按照JSON格式输出
|
||||
|
||||
**【禁止行为】**:
|
||||
- ❌ 禁止用一个镜头概括多个场景
|
||||
- ❌ 禁止跳过任何对话或独白
|
||||
- ❌ 禁止省略剧情发展过程
|
||||
- ❌ 禁止合并本应分开的镜头
|
||||
- ✅ 正确做法:剧本有多少内容,就拆解出对应数量的分镜,确保观众看完所有分镜能完整了解剧情
|
||||
|
||||
**【关键】场景描述详细度要求**(这些描述将直接用于视频生成模型):
|
||||
1. **时间(time)字段**:必须包含≥15字的详细描述
|
||||
- ✓ 好例子:"深夜22:30·月光从破窗斜射入仓库,在地面积水中形成银白色反光,墙角昏暗不清"
|
||||
- ✗ 差例子:"深夜"
|
||||
|
||||
2. **地点(location)字段**:必须包含≥20字的详细场景描述
|
||||
- ✓ 好例子:"废弃码头仓库·锈蚀货架林立,地面积水反射微弱灯光,墙角堆放腐朽木箱和渔网,空气中弥漫潮湿霉味"
|
||||
- ✗ 差例子:"仓库"
|
||||
|
||||
3. **动作(action)字段**:必须包含≥25字的详细动作描述,包括肢体细节和表情
|
||||
- ✓ 好例子:"陈峥弯腰双手握住撬棍用力撬动保险箱门,手臂青筋暴起,眉头紧锁,汗水从额头滑落脸颊,呼吸急促"
|
||||
- ✗ 差例子:"陈峥打开保险箱"
|
||||
|
||||
4. **结果(result)字段**:必须包含≥25字的详细视觉结果描述
|
||||
- ✓ 好例子:"保险箱门突然弹开发出刺耳金属声,扬起灰尘在手电筒光束中飘散,箱内空无一物只有几张发黄的旧报纸,陈峥表情从期待转为震惊和失望,瞳孔放大"
|
||||
- ✗ 差例子:"门打开了"
|
||||
|
||||
5. **氛围(atmosphere)字段**:必须包含≥20字的环境氛围描述,包括光线、色调、声音
|
||||
- ✓ 好例子:"昏暗冷色调·青灰色为主,只有手电筒光束在黑暗中晃动,远处传来海浪拍打码头的沉闷声,整体氛围压抑沉重"
|
||||
- ✗ 差例子:"昏暗"
|
||||
|
||||
**描述原则**:
|
||||
- 所有描述性字段要像为盲人讲述画面一样详细
|
||||
- 包含感官细节:视觉、听觉、触觉、嗅觉
|
||||
- 描述光线、色彩、质感、动态
|
||||
- 为视频生成AI提供足够的画面构建信息
|
||||
- 避免抽象词汇,使用具象的视觉化描述`, characterList, sceneList, scriptContent)
|
||||
|
||||
// 调用AI服务生成
|
||||
text, err := s.aiService.GenerateText(prompt, "")
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to generate storyboard", "error", err)
|
||||
return nil, fmt.Errorf("生成分镜头失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析JSON结果
|
||||
var result GenerateStoryboardResult
|
||||
if err := utils.SafeParseAIJSON(text, &result); err != nil {
|
||||
s.log.Errorw("Failed to parse storyboard JSON", "error", err, "response", text[:min(500, len(text))])
|
||||
return nil, fmt.Errorf("解析分镜头结果失败: %w", err)
|
||||
}
|
||||
|
||||
result.Total = len(result.Storyboards)
|
||||
|
||||
// 计算总时长(所有分镜时长之和)
|
||||
totalDuration := 0
|
||||
for _, sb := range result.Storyboards {
|
||||
totalDuration += sb.Duration
|
||||
}
|
||||
|
||||
s.log.Infow("Storyboard generated",
|
||||
"episode_id", episodeID,
|
||||
"count", result.Total,
|
||||
"total_duration_seconds", totalDuration)
|
||||
|
||||
// 保存分镜头到数据库
|
||||
if err := s.saveStoryboards(episodeID, result.Storyboards); err != nil {
|
||||
s.log.Errorw("Failed to save storyboards", "error", err)
|
||||
return nil, fmt.Errorf("保存分镜头失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新剧集时长(秒转分钟,向上取整)
|
||||
durationMinutes := (totalDuration + 59) / 60
|
||||
if err := s.db.Model(&models.Episode{}).Where("id = ?", episodeID).Update("duration", durationMinutes).Error; err != nil {
|
||||
s.log.Errorw("Failed to update episode duration", "error", err)
|
||||
// 不中断流程,只记录错误
|
||||
} else {
|
||||
s.log.Infow("Episode duration updated",
|
||||
"episode_id", episodeID,
|
||||
"duration_seconds", totalDuration,
|
||||
"duration_minutes", durationMinutes)
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// generateImagePrompt 生成专门用于图片生成的提示词(首帧静态画面)
|
||||
func (s *StoryboardService) generateImagePrompt(sb Storyboard) string {
|
||||
var parts []string
|
||||
|
||||
// 1. 完整的场景背景描述
|
||||
if sb.Location != "" {
|
||||
locationDesc := sb.Location
|
||||
if sb.Time != "" {
|
||||
locationDesc += ", " + sb.Time
|
||||
}
|
||||
parts = append(parts, locationDesc)
|
||||
}
|
||||
|
||||
// 2. 角色初始静态姿态(去除动作过程,只保留起始状态)
|
||||
if sb.Action != "" {
|
||||
initialPose := extractInitialPose(sb.Action)
|
||||
if initialPose != "" {
|
||||
parts = append(parts, initialPose)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 情绪氛围
|
||||
if sb.Emotion != "" {
|
||||
parts = append(parts, sb.Emotion)
|
||||
}
|
||||
|
||||
// 4. 动漫风格
|
||||
parts = append(parts, "anime style, first frame")
|
||||
|
||||
if len(parts) > 0 {
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
return "anime scene"
|
||||
}
|
||||
|
||||
// extractInitialPose 提取初始静态姿态(去除动作过程)
|
||||
func extractInitialPose(action string) string {
|
||||
// 去除动作过程关键词,保留初始状态描述
|
||||
processWords := []string{
|
||||
"然后", "接着", "接下来", "随后", "紧接着",
|
||||
"向下", "向上", "向前", "向后", "向左", "向右",
|
||||
"开始", "继续", "逐渐", "慢慢", "快速", "突然", "猛然",
|
||||
}
|
||||
|
||||
result := action
|
||||
for _, word := range processWords {
|
||||
if idx := strings.Index(result, word); idx > 0 {
|
||||
// 在动作过程词之前截断
|
||||
result = result[:idx]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 清理末尾标点
|
||||
result = strings.TrimRight(result, ",。,. ")
|
||||
return strings.TrimSpace(result)
|
||||
}
|
||||
|
||||
// extractSimpleLocation 提取简化的场景地点(去除详细描述)
|
||||
func extractSimpleLocation(location string) string {
|
||||
// 在"·"符号处截断,只保留主场景名称
|
||||
if idx := strings.Index(location, "·"); idx > 0 {
|
||||
return strings.TrimSpace(location[:idx])
|
||||
}
|
||||
|
||||
// 如果有逗号,只保留第一部分
|
||||
if idx := strings.Index(location, ","); idx > 0 {
|
||||
return strings.TrimSpace(location[:idx])
|
||||
}
|
||||
if idx := strings.Index(location, ","); idx > 0 {
|
||||
return strings.TrimSpace(location[:idx])
|
||||
}
|
||||
|
||||
// 限制长度不超过15个字符
|
||||
maxLen := 15
|
||||
if len(location) > maxLen {
|
||||
return strings.TrimSpace(location[:maxLen])
|
||||
}
|
||||
|
||||
return strings.TrimSpace(location)
|
||||
}
|
||||
|
||||
// extractSimplePose 提取简单的核心姿态关键词(不超过10个字)
|
||||
func extractSimplePose(action string) string {
|
||||
// 只提取前面最多10个字符作为核心姿态
|
||||
runes := []rune(action)
|
||||
maxLen := 10
|
||||
if len(runes) > maxLen {
|
||||
// 在标点符号处截断
|
||||
truncated := runes[:maxLen]
|
||||
for i := maxLen - 1; i >= 0; i-- {
|
||||
if truncated[i] == ',' || truncated[i] == '。' || truncated[i] == ',' || truncated[i] == '.' {
|
||||
truncated = runes[:i]
|
||||
break
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(string(truncated))
|
||||
}
|
||||
return strings.TrimSpace(action)
|
||||
}
|
||||
|
||||
// extractFirstFramePose 从动作描述中提取首帧静态姿态
|
||||
func extractFirstFramePose(action string) string {
|
||||
// 去除表示动作过程的关键词,保留初始状态
|
||||
processWords := []string{
|
||||
"然后", "接着", "向下", "向前", "走向", "冲向", "转身",
|
||||
"开始", "继续", "逐渐", "慢慢", "快速", "突然",
|
||||
}
|
||||
|
||||
pose := action
|
||||
for _, word := range processWords {
|
||||
// 简单处理:在这些词之前截断
|
||||
if idx := strings.Index(pose, word); idx > 0 {
|
||||
pose = pose[:idx]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 清理末尾标点
|
||||
pose = strings.TrimRight(pose, ",。,.")
|
||||
return strings.TrimSpace(pose)
|
||||
}
|
||||
|
||||
// extractCompositionType 从镜头类型中提取构图类型(去除运镜)
|
||||
func extractCompositionType(shotType string) string {
|
||||
// 去除运镜相关描述
|
||||
cameraMovements := []string{
|
||||
"晃动", "摇晃", "推进", "拉远", "跟随", "环绕",
|
||||
"运镜", "摄影", "移动", "旋转",
|
||||
}
|
||||
|
||||
comp := shotType
|
||||
for _, movement := range cameraMovements {
|
||||
comp = strings.ReplaceAll(comp, movement, "")
|
||||
}
|
||||
|
||||
// 清理多余的标点和空格
|
||||
comp = strings.ReplaceAll(comp, "··", "·")
|
||||
comp = strings.ReplaceAll(comp, "·", " ")
|
||||
comp = strings.TrimSpace(comp)
|
||||
|
||||
return comp
|
||||
}
|
||||
|
||||
// generateVideoPrompt 生成专门用于视频生成的提示词(包含运镜和动态元素)
|
||||
func (s *StoryboardService) generateVideoPrompt(sb Storyboard) string {
|
||||
var parts []string
|
||||
|
||||
// 1. 人物动作
|
||||
if sb.Action != "" {
|
||||
parts = append(parts, fmt.Sprintf("Action: %s", sb.Action))
|
||||
}
|
||||
|
||||
// 2. 对话
|
||||
if sb.Dialogue != "" {
|
||||
parts = append(parts, fmt.Sprintf("Dialogue: %s", sb.Dialogue))
|
||||
}
|
||||
|
||||
// 3. 镜头运动(视频特有)
|
||||
if sb.Movement != "" {
|
||||
parts = append(parts, fmt.Sprintf("Camera movement: %s", sb.Movement))
|
||||
}
|
||||
|
||||
// 4. 镜头类型和角度
|
||||
if sb.ShotType != "" {
|
||||
parts = append(parts, fmt.Sprintf("Shot type: %s", sb.ShotType))
|
||||
}
|
||||
if sb.Angle != "" {
|
||||
parts = append(parts, fmt.Sprintf("Camera angle: %s", sb.Angle))
|
||||
}
|
||||
|
||||
// 5. 场景环境
|
||||
if sb.Location != "" {
|
||||
locationDesc := sb.Location
|
||||
if sb.Time != "" {
|
||||
locationDesc += ", " + sb.Time
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("Scene: %s", locationDesc))
|
||||
}
|
||||
|
||||
// 6. 环境氛围
|
||||
if sb.Atmosphere != "" {
|
||||
parts = append(parts, fmt.Sprintf("Atmosphere: %s", sb.Atmosphere))
|
||||
}
|
||||
|
||||
// 7. 情绪和结果
|
||||
if sb.Emotion != "" {
|
||||
parts = append(parts, fmt.Sprintf("Mood: %s", sb.Emotion))
|
||||
}
|
||||
if sb.Result != "" {
|
||||
parts = append(parts, fmt.Sprintf("Result: %s", sb.Result))
|
||||
}
|
||||
|
||||
// 8. 音频元素
|
||||
if sb.BgmPrompt != "" {
|
||||
parts = append(parts, fmt.Sprintf("BGM: %s", sb.BgmPrompt))
|
||||
}
|
||||
if sb.SoundEffect != "" {
|
||||
parts = append(parts, fmt.Sprintf("Sound effects: %s", sb.SoundEffect))
|
||||
}
|
||||
|
||||
// 9. 视频风格要求
|
||||
parts = append(parts, "Style: cinematic anime style, smooth camera motion, natural character movement")
|
||||
|
||||
if len(parts) > 0 {
|
||||
return strings.Join(parts, ". ")
|
||||
}
|
||||
return "Anime style video scene"
|
||||
}
|
||||
|
||||
func (s *StoryboardService) saveStoryboards(episodeID string, storyboards []Storyboard) error {
|
||||
// 开启事务
|
||||
return s.db.Transaction(func(tx *gorm.DB) error {
|
||||
// 获取该剧集所有的分镜ID
|
||||
var storyboardIDs []uint
|
||||
if err := tx.Model(&models.Storyboard{}).
|
||||
Where("episode_id = ?", episodeID).
|
||||
Pluck("id", &storyboardIDs).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果有分镜,先清理关联的image_generations的storyboard_id
|
||||
if len(storyboardIDs) > 0 {
|
||||
if err := tx.Model(&models.ImageGeneration{}).
|
||||
Where("storyboard_id IN ?", storyboardIDs).
|
||||
Update("storyboard_id", nil).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 删除该剧集已有的分镜头
|
||||
if err := tx.Where("episode_id = ?", episodeID).Delete(&models.Storyboard{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 注意:不删除背景,因为背景是在分镜拆解前就提取好的
|
||||
// AI会直接返回scene_id,不需要在这里做字符串匹配
|
||||
|
||||
// 保存新的分镜头
|
||||
for _, sb := range storyboards {
|
||||
// 构建描述信息,包含对话
|
||||
description := fmt.Sprintf("【镜头类型】%s\n【运镜】%s\n【动作】%s\n【对话】%s\n【结果】%s\n【情绪】%s",
|
||||
sb.ShotType, sb.Movement, sb.Action, sb.Dialogue, sb.Result, sb.Emotion)
|
||||
|
||||
// 生成两种专用提示词
|
||||
imagePrompt := s.generateImagePrompt(sb) // 专用于图片生成
|
||||
videoPrompt := s.generateVideoPrompt(sb) // 专用于视频生成
|
||||
|
||||
// 处理 dialogue 字段
|
||||
var dialoguePtr *string
|
||||
if sb.Dialogue != "" {
|
||||
dialoguePtr = &sb.Dialogue
|
||||
}
|
||||
|
||||
// 使用AI直接返回的SceneID
|
||||
if sb.SceneID != nil {
|
||||
s.log.Infow("Background ID from AI",
|
||||
"shot_number", sb.ShotNumber,
|
||||
"scene_id", *sb.SceneID)
|
||||
}
|
||||
|
||||
epID, _ := strconv.ParseUint(episodeID, 10, 32)
|
||||
|
||||
// 处理 title 字段
|
||||
var titlePtr *string
|
||||
if sb.Title != "" {
|
||||
titlePtr = &sb.Title
|
||||
}
|
||||
|
||||
// 处理shot_type、angle、movement字段
|
||||
var shotTypePtr, anglePtr, movementPtr *string
|
||||
if sb.ShotType != "" {
|
||||
shotTypePtr = &sb.ShotType
|
||||
}
|
||||
if sb.Angle != "" {
|
||||
anglePtr = &sb.Angle
|
||||
}
|
||||
if sb.Movement != "" {
|
||||
movementPtr = &sb.Movement
|
||||
}
|
||||
|
||||
// 处理bgm_prompt、sound_effect字段
|
||||
var bgmPromptPtr, soundEffectPtr *string
|
||||
if sb.BgmPrompt != "" {
|
||||
bgmPromptPtr = &sb.BgmPrompt
|
||||
}
|
||||
if sb.SoundEffect != "" {
|
||||
soundEffectPtr = &sb.SoundEffect
|
||||
}
|
||||
|
||||
// 处理result、atmosphere字段
|
||||
var resultPtr, atmospherePtr *string
|
||||
if sb.Result != "" {
|
||||
resultPtr = &sb.Result
|
||||
}
|
||||
if sb.Atmosphere != "" {
|
||||
atmospherePtr = &sb.Atmosphere
|
||||
}
|
||||
|
||||
scene := models.Storyboard{
|
||||
EpisodeID: uint(epID),
|
||||
SceneID: sb.SceneID,
|
||||
StoryboardNumber: sb.ShotNumber,
|
||||
Title: titlePtr,
|
||||
Location: &sb.Location,
|
||||
Time: &sb.Time,
|
||||
ShotType: shotTypePtr,
|
||||
Angle: anglePtr,
|
||||
Movement: movementPtr,
|
||||
Description: &description,
|
||||
Action: &sb.Action,
|
||||
Result: resultPtr,
|
||||
Atmosphere: atmospherePtr,
|
||||
Dialogue: dialoguePtr,
|
||||
ImagePrompt: &imagePrompt,
|
||||
VideoPrompt: &videoPrompt,
|
||||
BgmPrompt: bgmPromptPtr,
|
||||
SoundEffect: soundEffectPtr,
|
||||
Duration: sb.Duration,
|
||||
}
|
||||
|
||||
if err := tx.Create(&scene).Error; err != nil {
|
||||
s.log.Errorw("Failed to create scene", "error", err, "shot_number", sb.ShotNumber)
|
||||
return err
|
||||
}
|
||||
|
||||
// 关联角色
|
||||
if len(sb.Characters) > 0 {
|
||||
var characters []models.Character
|
||||
if err := tx.Where("id IN ?", sb.Characters).Find(&characters).Error; err != nil {
|
||||
s.log.Warnw("Failed to load characters for association", "error", err, "character_ids", sb.Characters)
|
||||
} else if len(characters) > 0 {
|
||||
if err := tx.Model(&scene).Association("Characters").Append(characters); err != nil {
|
||||
s.log.Warnw("Failed to associate characters", "error", err, "shot_number", sb.ShotNumber)
|
||||
} else {
|
||||
s.log.Infow("Characters associated successfully",
|
||||
"shot_number", sb.ShotNumber,
|
||||
"character_ids", sb.Characters,
|
||||
"count", len(characters))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Infow("Storyboards saved successfully", "episode_id", episodeID, "count", len(storyboards))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateStoryboardCharacters 更新分镜的角色关联
|
||||
func (s *StoryboardService) UpdateStoryboardCharacters(storyboardID string, characterIDs []uint) error {
|
||||
// 查找分镜
|
||||
var storyboard models.Storyboard
|
||||
if err := s.db.First(&storyboard, storyboardID).Error; err != nil {
|
||||
return fmt.Errorf("storyboard not found: %w", err)
|
||||
}
|
||||
|
||||
// 清除现有的角色关联
|
||||
if err := s.db.Model(&storyboard).Association("Characters").Clear(); err != nil {
|
||||
return fmt.Errorf("failed to clear characters: %w", err)
|
||||
}
|
||||
|
||||
// 如果有新的角色ID,加载并关联
|
||||
if len(characterIDs) > 0 {
|
||||
var characters []models.Character
|
||||
if err := s.db.Where("id IN ?", characterIDs).Find(&characters).Error; err != nil {
|
||||
return fmt.Errorf("failed to find characters: %w", err)
|
||||
}
|
||||
|
||||
if err := s.db.Model(&storyboard).Association("Characters").Append(characters); err != nil {
|
||||
return fmt.Errorf("failed to associate characters: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Infow("Storyboard characters updated", "storyboard_id", storyboardID, "character_count", len(characterIDs))
|
||||
return nil
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
138
application/services/storyboard_update_full.go
Normal file
138
application/services/storyboard_update_full.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/drama-generator/backend/domain/models"
|
||||
)
|
||||
|
||||
// UpdateStoryboard 更新分镜的所有字段,并重新生成提示词
|
||||
func (s *StoryboardService) UpdateStoryboard(storyboardID string, updates map[string]interface{}) error {
|
||||
// 查找分镜
|
||||
var storyboard models.Storyboard
|
||||
if err := s.db.First(&storyboard, storyboardID).Error; err != nil {
|
||||
return fmt.Errorf("storyboard not found: %w", err)
|
||||
}
|
||||
|
||||
// 构建用于重新生成提示词的Storyboard结构
|
||||
sb := Storyboard{
|
||||
ShotNumber: storyboard.StoryboardNumber,
|
||||
}
|
||||
|
||||
// 从updates中提取字段并更新
|
||||
updateData := make(map[string]interface{})
|
||||
|
||||
if val, ok := updates["title"].(string); ok && val != "" {
|
||||
updateData["title"] = val
|
||||
sb.Title = val
|
||||
}
|
||||
if val, ok := updates["shot_type"].(string); ok && val != "" {
|
||||
updateData["shot_type"] = val
|
||||
sb.ShotType = val
|
||||
}
|
||||
if val, ok := updates["angle"].(string); ok && val != "" {
|
||||
updateData["angle"] = val
|
||||
sb.Angle = val
|
||||
}
|
||||
if val, ok := updates["movement"].(string); ok && val != "" {
|
||||
updateData["movement"] = val
|
||||
sb.Movement = val
|
||||
}
|
||||
if val, ok := updates["location"].(string); ok && val != "" {
|
||||
updateData["location"] = val
|
||||
sb.Location = val
|
||||
}
|
||||
if val, ok := updates["time"].(string); ok && val != "" {
|
||||
updateData["time"] = val
|
||||
sb.Time = val
|
||||
}
|
||||
if val, ok := updates["action"].(string); ok && val != "" {
|
||||
updateData["action"] = val
|
||||
sb.Action = val
|
||||
}
|
||||
if val, ok := updates["dialogue"].(string); ok && val != "" {
|
||||
updateData["dialogue"] = val
|
||||
sb.Dialogue = val
|
||||
}
|
||||
if val, ok := updates["result"].(string); ok && val != "" {
|
||||
updateData["result"] = val
|
||||
sb.Result = val
|
||||
}
|
||||
if val, ok := updates["atmosphere"].(string); ok && val != "" {
|
||||
updateData["atmosphere"] = val
|
||||
sb.Atmosphere = val
|
||||
}
|
||||
if val, ok := updates["description"].(string); ok && val != "" {
|
||||
updateData["description"] = val
|
||||
}
|
||||
if val, ok := updates["bgm_prompt"].(string); ok && val != "" {
|
||||
updateData["bgm_prompt"] = val
|
||||
sb.BgmPrompt = val
|
||||
}
|
||||
if val, ok := updates["sound_effect"].(string); ok && val != "" {
|
||||
updateData["sound_effect"] = val
|
||||
sb.SoundEffect = val
|
||||
}
|
||||
if val, ok := updates["duration"].(float64); ok {
|
||||
updateData["duration"] = int(val)
|
||||
sb.Duration = int(val)
|
||||
}
|
||||
|
||||
// 使用当前数据库值填充缺失字段(用于生成提示词)
|
||||
if sb.Title == "" && storyboard.Title != nil {
|
||||
sb.Title = *storyboard.Title
|
||||
}
|
||||
if sb.ShotType == "" && storyboard.ShotType != nil {
|
||||
sb.ShotType = *storyboard.ShotType
|
||||
}
|
||||
if sb.Angle == "" && storyboard.Angle != nil {
|
||||
sb.Angle = *storyboard.Angle
|
||||
}
|
||||
if sb.Movement == "" && storyboard.Movement != nil {
|
||||
sb.Movement = *storyboard.Movement
|
||||
}
|
||||
if sb.Location == "" && storyboard.Location != nil {
|
||||
sb.Location = *storyboard.Location
|
||||
}
|
||||
if sb.Time == "" && storyboard.Time != nil {
|
||||
sb.Time = *storyboard.Time
|
||||
}
|
||||
if sb.Action == "" && storyboard.Action != nil {
|
||||
sb.Action = *storyboard.Action
|
||||
}
|
||||
if sb.Dialogue == "" && storyboard.Dialogue != nil {
|
||||
sb.Dialogue = *storyboard.Dialogue
|
||||
}
|
||||
if sb.Result == "" && storyboard.Result != nil {
|
||||
sb.Result = *storyboard.Result
|
||||
}
|
||||
if sb.Atmosphere == "" && storyboard.Atmosphere != nil {
|
||||
sb.Atmosphere = *storyboard.Atmosphere
|
||||
}
|
||||
if sb.BgmPrompt == "" && storyboard.BgmPrompt != nil {
|
||||
sb.BgmPrompt = *storyboard.BgmPrompt
|
||||
}
|
||||
if sb.SoundEffect == "" && storyboard.SoundEffect != nil {
|
||||
sb.SoundEffect = *storyboard.SoundEffect
|
||||
}
|
||||
if sb.Duration == 0 {
|
||||
sb.Duration = storyboard.Duration
|
||||
}
|
||||
|
||||
// 只重新生成video_prompt
|
||||
// image_prompt不自动更新,因为可能对应多张已生成的帧图片
|
||||
videoPrompt := s.generateVideoPrompt(sb)
|
||||
|
||||
updateData["video_prompt"] = videoPrompt
|
||||
|
||||
// 更新数据库
|
||||
if err := s.db.Model(&storyboard).Updates(updateData).Error; err != nil {
|
||||
return fmt.Errorf("failed to update storyboard: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infow("Storyboard updated successfully",
|
||||
"storyboard_id", storyboardID,
|
||||
"fields_updated", len(updateData))
|
||||
|
||||
return nil
|
||||
}
|
||||
113
application/services/task_service.go
Normal file
113
application/services/task_service.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type TaskService struct {
|
||||
db *gorm.DB
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewTaskService(db *gorm.DB, log *logger.Logger) *TaskService {
|
||||
return &TaskService{
|
||||
db: db,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTask 创建新任务
|
||||
func (s *TaskService) CreateTask(taskType, resourceID string) (*models.AsyncTask, error) {
|
||||
task := &models.AsyncTask{
|
||||
ID: uuid.New().String(),
|
||||
Type: taskType,
|
||||
Status: "pending",
|
||||
Progress: 0,
|
||||
ResourceID: resourceID,
|
||||
}
|
||||
|
||||
if err := s.db.Create(task).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create task: %w", err)
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// UpdateTaskStatus 更新任务状态
|
||||
func (s *TaskService) UpdateTaskStatus(taskID, status string, progress int, message string) error {
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
if status == "completed" || status == "failed" {
|
||||
now := time.Now()
|
||||
updates["completed_at"] = &now
|
||||
}
|
||||
|
||||
return s.db.Model(&models.AsyncTask{}).
|
||||
Where("id = ?", taskID).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateTaskError 更新任务错误
|
||||
func (s *TaskService) UpdateTaskError(taskID string, err error) error {
|
||||
now := time.Now()
|
||||
return s.db.Model(&models.AsyncTask{}).
|
||||
Where("id = ?", taskID).
|
||||
Updates(map[string]interface{}{
|
||||
"status": "failed",
|
||||
"error": err.Error(),
|
||||
"progress": 0,
|
||||
"completed_at": &now,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateTaskResult 更新任务结果
|
||||
func (s *TaskService) UpdateTaskResult(taskID string, result interface{}) error {
|
||||
resultJSON, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal result: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
return s.db.Model(&models.AsyncTask{}).
|
||||
Where("id = ?", taskID).
|
||||
Updates(map[string]interface{}{
|
||||
"status": "completed",
|
||||
"progress": 100,
|
||||
"result": string(resultJSON),
|
||||
"completed_at": &now,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// GetTask 获取任务信息
|
||||
func (s *TaskService) GetTask(taskID string) (*models.AsyncTask, error) {
|
||||
var task models.AsyncTask
|
||||
if err := s.db.Where("id = ?", taskID).First(&task).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// GetTasksByResource 获取资源相关的所有任务
|
||||
func (s *TaskService) GetTasksByResource(resourceID string) ([]*models.AsyncTask, error) {
|
||||
var tasks []*models.AsyncTask
|
||||
if err := s.db.Where("resource_id = ?", resourceID).
|
||||
Order("created_at DESC").
|
||||
Find(&tasks).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tasks, nil
|
||||
}
|
||||
109
application/services/upload_service.go
Normal file
109
application/services/upload_service.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/drama-generator/backend/pkg/config"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type UploadService struct {
|
||||
storagePath string
|
||||
baseURL string
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewUploadService(cfg *config.Config, log *logger.Logger) (*UploadService, error) {
|
||||
// 确保存储目录存在
|
||||
if err := os.MkdirAll(cfg.Storage.LocalPath, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create storage directory: %w", err)
|
||||
}
|
||||
|
||||
return &UploadService{
|
||||
storagePath: cfg.Storage.LocalPath,
|
||||
baseURL: cfg.Storage.BaseURL,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UploadFile 上传文件到本地存储
|
||||
func (s *UploadService) UploadFile(file io.Reader, fileName, contentType string, category string) (string, error) {
|
||||
// 创建分类目录
|
||||
categoryPath := filepath.Join(s.storagePath, category)
|
||||
if err := os.MkdirAll(categoryPath, 0755); err != nil {
|
||||
return "", fmt.Errorf("failed to create category directory: %w", err)
|
||||
}
|
||||
|
||||
// 生成唯一文件名
|
||||
ext := filepath.Ext(fileName)
|
||||
uniqueID := uuid.New().String()
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
newFileName := fmt.Sprintf("%s_%s%s", timestamp, uniqueID, ext)
|
||||
filePath := filepath.Join(categoryPath, newFileName)
|
||||
|
||||
// 创建文件
|
||||
dst, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to create file", "error", err, "path", filePath)
|
||||
return "", fmt.Errorf("创建文件失败: %w", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// 写入文件
|
||||
if _, err := io.Copy(dst, file); err != nil {
|
||||
s.log.Errorw("Failed to write file", "error", err, "path", filePath)
|
||||
return "", fmt.Errorf("写入文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建访问URL
|
||||
fileURL := fmt.Sprintf("%s/%s/%s", s.baseURL, category, newFileName)
|
||||
|
||||
s.log.Infow("File uploaded successfully", "path", filePath, "url", fileURL)
|
||||
return fileURL, nil
|
||||
}
|
||||
|
||||
// UploadCharacterImage 上传角色图片
|
||||
func (s *UploadService) UploadCharacterImage(file io.Reader, fileName, contentType string) (string, error) {
|
||||
return s.UploadFile(file, fileName, contentType, "characters")
|
||||
}
|
||||
|
||||
// DeleteFile 删除本地文件
|
||||
func (s *UploadService) DeleteFile(fileURL string) error {
|
||||
// 从URL中提取相对路径
|
||||
// URL格式: http://localhost:8080/static/characters/20060102_150405_uuid.jpg
|
||||
relPath := s.extractRelativePathFromURL(fileURL)
|
||||
if relPath == "" {
|
||||
return fmt.Errorf("invalid file URL")
|
||||
}
|
||||
|
||||
filePath := filepath.Join(s.storagePath, relPath)
|
||||
err := os.Remove(filePath)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to delete file", "error", err, "path", filePath)
|
||||
return fmt.Errorf("删除文件失败: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infow("File deleted successfully", "path", filePath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractRelativePathFromURL 从URL中提取相对路径
|
||||
func (s *UploadService) extractRelativePathFromURL(fileURL string) string {
|
||||
// 从baseURL后面提取路径
|
||||
// 例如: http://localhost:8080/static/characters/xxx.jpg -> characters/xxx.jpg
|
||||
if len(fileURL) <= len(s.baseURL) {
|
||||
return ""
|
||||
}
|
||||
return fileURL[len(s.baseURL)+1:] // +1 for the '/'
|
||||
}
|
||||
|
||||
// GetPresignedURL 本地存储不需要预签名URL,直接返回原URL
|
||||
func (s *UploadService) GetPresignedURL(objectName string, expiry time.Duration) (string, error) {
|
||||
// 本地存储通过静态文件服务直接访问,不需要预签名
|
||||
return fmt.Sprintf("%s/%s", s.baseURL, objectName), nil
|
||||
}
|
||||
566
application/services/video_generation_service.go
Normal file
566
application/services/video_generation_service.go
Normal file
@@ -0,0 +1,566 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
models "github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/infrastructure/storage"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"github.com/drama-generator/backend/pkg/video"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type VideoGenerationService struct {
|
||||
db *gorm.DB
|
||||
transferService *ResourceTransferService
|
||||
log *logger.Logger
|
||||
localStorage *storage.LocalStorage
|
||||
aiService *AIService
|
||||
}
|
||||
|
||||
func NewVideoGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, aiService *AIService, log *logger.Logger) *VideoGenerationService {
|
||||
service := &VideoGenerationService{
|
||||
db: db,
|
||||
localStorage: localStorage,
|
||||
transferService: transferService,
|
||||
aiService: aiService,
|
||||
log: log,
|
||||
}
|
||||
|
||||
go service.RecoverPendingTasks()
|
||||
|
||||
return service
|
||||
}
|
||||
|
||||
type GenerateVideoRequest struct {
|
||||
StoryboardID *uint `json:"storyboard_id"`
|
||||
DramaID string `json:"drama_id" binding:"required"`
|
||||
ImageGenID *uint `json:"image_gen_id"`
|
||||
|
||||
// 参考图模式:single, first_last, multiple, none
|
||||
ReferenceMode string `json:"reference_mode"`
|
||||
|
||||
// 单图模式
|
||||
ImageURL string `json:"image_url"`
|
||||
|
||||
// 首尾帧模式
|
||||
FirstFrameURL *string `json:"first_frame_url"`
|
||||
LastFrameURL *string `json:"last_frame_url"`
|
||||
|
||||
// 多图模式
|
||||
ReferenceImageURLs []string `json:"reference_image_urls"`
|
||||
|
||||
Prompt string `json:"prompt" binding:"required,min=5,max=2000"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
Duration *int `json:"duration"`
|
||||
FPS *int `json:"fps"`
|
||||
AspectRatio *string `json:"aspect_ratio"`
|
||||
Style *string `json:"style"`
|
||||
MotionLevel *int `json:"motion_level"`
|
||||
CameraMotion *string `json:"camera_motion"`
|
||||
Seed *int64 `json:"seed"`
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) GenerateVideo(request *GenerateVideoRequest) (*models.VideoGeneration, error) {
|
||||
if request.StoryboardID != nil {
|
||||
var storyboard models.Storyboard
|
||||
if err := s.db.Preload("Episode").Where("id = ?", *request.StoryboardID).First(&storyboard).Error; err != nil {
|
||||
return nil, fmt.Errorf("storyboard not found")
|
||||
}
|
||||
if fmt.Sprintf("%d", storyboard.Episode.DramaID) != request.DramaID {
|
||||
return nil, fmt.Errorf("storyboard does not belong to drama")
|
||||
}
|
||||
}
|
||||
|
||||
if request.ImageGenID != nil {
|
||||
var imageGen models.ImageGeneration
|
||||
if err := s.db.Where("id = ?", *request.ImageGenID).First(&imageGen).Error; err != nil {
|
||||
return nil, fmt.Errorf("image generation not found")
|
||||
}
|
||||
}
|
||||
|
||||
provider := request.Provider
|
||||
if provider == "" {
|
||||
provider = "doubao"
|
||||
}
|
||||
|
||||
dramaID, _ := strconv.ParseUint(request.DramaID, 10, 32)
|
||||
|
||||
videoGen := &models.VideoGeneration{
|
||||
StoryboardID: request.StoryboardID,
|
||||
DramaID: uint(dramaID),
|
||||
ImageGenID: request.ImageGenID,
|
||||
Provider: provider,
|
||||
Prompt: request.Prompt,
|
||||
Model: request.Model,
|
||||
Duration: request.Duration,
|
||||
FPS: request.FPS,
|
||||
AspectRatio: request.AspectRatio,
|
||||
Style: request.Style,
|
||||
MotionLevel: request.MotionLevel,
|
||||
CameraMotion: request.CameraMotion,
|
||||
Seed: request.Seed,
|
||||
Status: models.VideoStatusPending,
|
||||
}
|
||||
|
||||
// 根据参考图模式处理不同的参数
|
||||
if request.ReferenceMode != "" {
|
||||
videoGen.ReferenceMode = &request.ReferenceMode
|
||||
}
|
||||
|
||||
switch request.ReferenceMode {
|
||||
case "single":
|
||||
// 单图模式
|
||||
if request.ImageURL != "" {
|
||||
videoGen.ImageURL = &request.ImageURL
|
||||
}
|
||||
case "first_last":
|
||||
// 首尾帧模式
|
||||
if request.FirstFrameURL != nil {
|
||||
videoGen.FirstFrameURL = request.FirstFrameURL
|
||||
}
|
||||
if request.LastFrameURL != nil {
|
||||
videoGen.LastFrameURL = request.LastFrameURL
|
||||
}
|
||||
case "multiple":
|
||||
// 多图模式
|
||||
if len(request.ReferenceImageURLs) > 0 {
|
||||
referenceImagesJSON, err := json.Marshal(request.ReferenceImageURLs)
|
||||
if err == nil {
|
||||
referenceImagesStr := string(referenceImagesJSON)
|
||||
videoGen.ReferenceImageURLs = &referenceImagesStr
|
||||
}
|
||||
}
|
||||
case "none":
|
||||
// 无参考图,纯文本生成
|
||||
default:
|
||||
// 向后兼容:如果没有指定模式,根据提供的参数自动判断
|
||||
if request.ImageURL != "" {
|
||||
videoGen.ImageURL = &request.ImageURL
|
||||
mode := "single"
|
||||
videoGen.ReferenceMode = &mode
|
||||
} else if request.FirstFrameURL != nil || request.LastFrameURL != nil {
|
||||
videoGen.FirstFrameURL = request.FirstFrameURL
|
||||
videoGen.LastFrameURL = request.LastFrameURL
|
||||
mode := "first_last"
|
||||
videoGen.ReferenceMode = &mode
|
||||
} else if len(request.ReferenceImageURLs) > 0 {
|
||||
referenceImagesJSON, err := json.Marshal(request.ReferenceImageURLs)
|
||||
if err == nil {
|
||||
referenceImagesStr := string(referenceImagesJSON)
|
||||
videoGen.ReferenceImageURLs = &referenceImagesStr
|
||||
mode := "multiple"
|
||||
videoGen.ReferenceMode = &mode
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.db.Create(videoGen).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create record: %w", err)
|
||||
}
|
||||
|
||||
go s.ProcessVideoGeneration(videoGen.ID)
|
||||
|
||||
return videoGen, nil
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) ProcessVideoGeneration(videoGenID uint) {
|
||||
var videoGen models.VideoGeneration
|
||||
if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
|
||||
s.log.Errorw("Failed to load video generation", "error", err, "id", videoGenID)
|
||||
return
|
||||
}
|
||||
|
||||
s.db.Model(&videoGen).Update("status", models.VideoStatusProcessing)
|
||||
|
||||
client, err := s.getVideoClient(videoGen.Provider, videoGen.Model)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to get video client", "error", err, "provider", videoGen.Provider, "model", videoGen.Model)
|
||||
s.updateVideoGenError(videoGenID, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Infow("Starting video generation", "id", videoGenID, "prompt", videoGen.Prompt, "provider", videoGen.Provider)
|
||||
|
||||
var opts []video.VideoOption
|
||||
if videoGen.Model != "" {
|
||||
opts = append(opts, video.WithModel(videoGen.Model))
|
||||
}
|
||||
if videoGen.Duration != nil {
|
||||
opts = append(opts, video.WithDuration(*videoGen.Duration))
|
||||
}
|
||||
if videoGen.FPS != nil {
|
||||
opts = append(opts, video.WithFPS(*videoGen.FPS))
|
||||
}
|
||||
if videoGen.AspectRatio != nil {
|
||||
opts = append(opts, video.WithAspectRatio(*videoGen.AspectRatio))
|
||||
}
|
||||
if videoGen.Style != nil {
|
||||
opts = append(opts, video.WithStyle(*videoGen.Style))
|
||||
}
|
||||
if videoGen.MotionLevel != nil {
|
||||
opts = append(opts, video.WithMotionLevel(*videoGen.MotionLevel))
|
||||
}
|
||||
if videoGen.CameraMotion != nil {
|
||||
opts = append(opts, video.WithCameraMotion(*videoGen.CameraMotion))
|
||||
}
|
||||
if videoGen.Seed != nil {
|
||||
opts = append(opts, video.WithSeed(*videoGen.Seed))
|
||||
}
|
||||
|
||||
// 根据参考图模式添加相应的选项
|
||||
if videoGen.ReferenceMode != nil {
|
||||
switch *videoGen.ReferenceMode {
|
||||
case "first_last":
|
||||
// 首尾帧模式
|
||||
if videoGen.FirstFrameURL != nil {
|
||||
opts = append(opts, video.WithFirstFrame(*videoGen.FirstFrameURL))
|
||||
}
|
||||
if videoGen.LastFrameURL != nil {
|
||||
opts = append(opts, video.WithLastFrame(*videoGen.LastFrameURL))
|
||||
}
|
||||
case "multiple":
|
||||
// 多图模式
|
||||
if videoGen.ReferenceImageURLs != nil {
|
||||
var imageURLs []string
|
||||
if err := json.Unmarshal([]byte(*videoGen.ReferenceImageURLs), &imageURLs); err == nil {
|
||||
opts = append(opts, video.WithReferenceImages(imageURLs))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 构造imageURL参数(单图模式使用,其他模式传空字符串)
|
||||
imageURL := ""
|
||||
if videoGen.ImageURL != nil {
|
||||
imageURL = *videoGen.ImageURL
|
||||
}
|
||||
|
||||
result, err := client.GenerateVideo(imageURL, videoGen.Prompt, opts...)
|
||||
if err != nil {
|
||||
s.log.Errorw("Video generation API call failed", "error", err, "id", videoGenID)
|
||||
s.updateVideoGenError(videoGenID, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if result.TaskID != "" {
|
||||
s.db.Model(&videoGen).Updates(map[string]interface{}{
|
||||
"task_id": result.TaskID,
|
||||
"status": models.VideoStatusProcessing,
|
||||
})
|
||||
go s.pollTaskStatus(videoGenID, result.TaskID, videoGen.Provider, videoGen.Model)
|
||||
return
|
||||
}
|
||||
|
||||
if result.VideoURL != "" {
|
||||
s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
|
||||
return
|
||||
}
|
||||
|
||||
s.updateVideoGenError(videoGenID, "no task ID or video URL returned")
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) pollTaskStatus(videoGenID uint, taskID string, provider string, model string) {
|
||||
client, err := s.getVideoClient(provider, model)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to get video client for polling", "error", err)
|
||||
s.updateVideoGenError(videoGenID, "failed to get video client")
|
||||
return
|
||||
}
|
||||
|
||||
maxAttempts := 300
|
||||
interval := 10 * time.Second
|
||||
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
time.Sleep(interval)
|
||||
|
||||
var videoGen models.VideoGeneration
|
||||
if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
|
||||
s.log.Errorw("Failed to load video generation", "error", err, "id", videoGenID)
|
||||
return
|
||||
}
|
||||
|
||||
if videoGen.Status != models.VideoStatusProcessing {
|
||||
s.log.Infow("Video generation status changed, stopping poll", "id", videoGenID, "status", videoGen.Status)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := client.GetTaskStatus(taskID)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to get task status", "error", err, "task_id", taskID)
|
||||
continue
|
||||
}
|
||||
|
||||
if result.Completed {
|
||||
if result.VideoURL != "" {
|
||||
s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
|
||||
return
|
||||
}
|
||||
s.updateVideoGenError(videoGenID, "task completed but no video URL")
|
||||
return
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
s.updateVideoGenError(videoGenID, result.Error)
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Infow("Video generation in progress", "id", videoGenID, "attempt", attempt+1)
|
||||
}
|
||||
|
||||
s.updateVideoGenError(videoGenID, "polling timeout")
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) completeVideoGeneration(videoGenID uint, videoURL string, duration *int, width *int, height *int, firstFrameURL *string) {
|
||||
// 下载视频到本地存储(仅用于缓存,不更新数据库)
|
||||
if s.localStorage != nil && videoURL != "" {
|
||||
_, err := s.localStorage.DownloadFromURL(videoURL, "videos")
|
||||
if err != nil {
|
||||
s.log.Warnw("Failed to download video to local storage",
|
||||
"error", err,
|
||||
"id", videoGenID,
|
||||
"original_url", videoURL)
|
||||
} else {
|
||||
s.log.Infow("Video downloaded to local storage for caching",
|
||||
"id", videoGenID,
|
||||
"original_url", videoURL)
|
||||
}
|
||||
}
|
||||
|
||||
// 下载首帧图片到本地存储(仅用于缓存,不更新数据库)
|
||||
if firstFrameURL != nil && *firstFrameURL != "" && s.localStorage != nil {
|
||||
_, err := s.localStorage.DownloadFromURL(*firstFrameURL, "video_frames")
|
||||
if err != nil {
|
||||
s.log.Warnw("Failed to download first frame to local storage",
|
||||
"error", err,
|
||||
"id", videoGenID,
|
||||
"original_url", *firstFrameURL)
|
||||
} else {
|
||||
s.log.Infow("First frame downloaded to local storage for caching",
|
||||
"id", videoGenID,
|
||||
"original_url", *firstFrameURL)
|
||||
}
|
||||
}
|
||||
|
||||
// 数据库中保持使用原始URL
|
||||
updates := map[string]interface{}{
|
||||
"status": models.VideoStatusCompleted,
|
||||
"video_url": videoURL,
|
||||
}
|
||||
if duration != nil {
|
||||
updates["duration"] = *duration
|
||||
}
|
||||
if width != nil {
|
||||
updates["width"] = *width
|
||||
}
|
||||
if height != nil {
|
||||
updates["height"] = *height
|
||||
}
|
||||
if firstFrameURL != nil {
|
||||
updates["first_frame_url"] = *firstFrameURL
|
||||
}
|
||||
|
||||
if err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(updates).Error; err != nil {
|
||||
s.log.Errorw("Failed to update video generation", "error", err, "id", videoGenID)
|
||||
return
|
||||
}
|
||||
|
||||
var videoGen models.VideoGeneration
|
||||
if err := s.db.First(&videoGen, videoGenID).Error; err == nil {
|
||||
if videoGen.StoryboardID != nil {
|
||||
if err := s.db.Model(&models.Storyboard{}).Where("id = ?", *videoGen.StoryboardID).Update("video_url", videoURL).Error; err != nil {
|
||||
s.log.Warnw("Failed to update storyboard video_url", "storyboard_id", *videoGen.StoryboardID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.log.Infow("Video generation completed", "id", videoGenID, "url", videoURL)
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) updateVideoGenError(videoGenID uint, errorMsg string) {
|
||||
if err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(map[string]interface{}{
|
||||
"status": models.VideoStatusFailed,
|
||||
"error_msg": errorMsg,
|
||||
}).Error; err != nil {
|
||||
s.log.Errorw("Failed to update video generation error", "error", err, "id", videoGenID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) getVideoClient(provider string, modelName string) (video.VideoClient, error) {
|
||||
// 根据模型名称获取AI配置
|
||||
var config *models.AIServiceConfig
|
||||
var err error
|
||||
|
||||
if modelName != "" {
|
||||
config, err = s.aiService.GetConfigForModel("video", modelName)
|
||||
if err != nil {
|
||||
s.log.Warnw("Failed to get config for model, using default", "model", modelName, "error", err)
|
||||
config, err = s.aiService.GetDefaultConfig("video")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no video AI config found: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
config, err = s.aiService.GetDefaultConfig("video")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("no video AI config found: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 使用配置中的信息创建客户端
|
||||
baseURL := config.BaseURL
|
||||
apiKey := config.APIKey
|
||||
model := modelName
|
||||
if model == "" && len(config.Model) > 0 {
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
// 根据配置中的 provider 创建对应的客户端
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch config.Provider {
|
||||
case "chatfire":
|
||||
endpoint = "/video/generations"
|
||||
queryEndpoint = "/video/task/{taskId}"
|
||||
return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
|
||||
case "doubao", "volcengine", "volces":
|
||||
endpoint = "/contents/generations/tasks"
|
||||
queryEndpoint = "/contents/generations/tasks/{taskId}"
|
||||
return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
|
||||
case "openai":
|
||||
// OpenAI Sora 使用 /v1/videos 端点
|
||||
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
|
||||
case "runway":
|
||||
return video.NewRunwayClient(baseURL, apiKey, model), nil
|
||||
case "pika":
|
||||
return video.NewPikaClient(baseURL, apiKey, model), nil
|
||||
case "minimax":
|
||||
return video.NewMinimaxClient(baseURL, apiKey, model), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported video provider: %s", provider)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) RecoverPendingTasks() {
|
||||
var pendingVideos []models.VideoGeneration
|
||||
if err := s.db.Where("status = ? AND task_id != ''", models.VideoStatusProcessing).Find(&pendingVideos).Error; err != nil {
|
||||
s.log.Errorw("Failed to load pending video tasks", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.log.Infow("Recovering pending video generation tasks", "count", len(pendingVideos))
|
||||
|
||||
for _, videoGen := range pendingVideos {
|
||||
go s.pollTaskStatus(videoGen.ID, *videoGen.TaskID, videoGen.Provider, videoGen.Model)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) GetVideoGeneration(id uint) (*models.VideoGeneration, error) {
|
||||
var videoGen models.VideoGeneration
|
||||
if err := s.db.First(&videoGen, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &videoGen, nil
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) ListVideoGenerations(dramaID *uint, storyboardID *uint, status string, limit int, offset int) ([]*models.VideoGeneration, int64, error) {
|
||||
var videos []*models.VideoGeneration
|
||||
var total int64
|
||||
|
||||
query := s.db.Model(&models.VideoGeneration{})
|
||||
|
||||
if dramaID != nil {
|
||||
query = query.Where("drama_id = ?", *dramaID)
|
||||
}
|
||||
if storyboardID != nil {
|
||||
query = query.Where("storyboard_id = ?", *storyboardID)
|
||||
}
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if err := query.Order("created_at DESC").Limit(limit).Offset(offset).Find(&videos).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return videos, total, nil
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) GenerateVideoFromImage(imageGenID uint) (*models.VideoGeneration, error) {
|
||||
var imageGen models.ImageGeneration
|
||||
if err := s.db.First(&imageGen, imageGenID).Error; err != nil {
|
||||
return nil, fmt.Errorf("image generation not found")
|
||||
}
|
||||
|
||||
if imageGen.Status != models.ImageStatusCompleted || imageGen.ImageURL == nil {
|
||||
return nil, fmt.Errorf("image is not ready")
|
||||
}
|
||||
|
||||
// 获取关联的Storyboard以获取时长
|
||||
var duration *int
|
||||
if imageGen.StoryboardID != nil {
|
||||
var storyboard models.Storyboard
|
||||
if err := s.db.Where("id = ?", *imageGen.StoryboardID).First(&storyboard).Error; err == nil {
|
||||
duration = &storyboard.Duration
|
||||
s.log.Infow("Using storyboard duration for video generation",
|
||||
"storyboard_id", *imageGen.StoryboardID,
|
||||
"duration", storyboard.Duration)
|
||||
}
|
||||
}
|
||||
|
||||
req := &GenerateVideoRequest{
|
||||
DramaID: fmt.Sprintf("%d", imageGen.DramaID),
|
||||
StoryboardID: imageGen.StoryboardID,
|
||||
ImageGenID: &imageGenID,
|
||||
ImageURL: *imageGen.ImageURL,
|
||||
Prompt: imageGen.Prompt,
|
||||
Provider: "doubao",
|
||||
Duration: duration,
|
||||
}
|
||||
|
||||
return s.GenerateVideo(req)
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) BatchGenerateVideosForEpisode(episodeID string) ([]*models.VideoGeneration, error) {
|
||||
var episode models.Episode
|
||||
if err := s.db.Preload("Storyboards").Where("id = ?", episodeID).First(&episode).Error; err != nil {
|
||||
return nil, fmt.Errorf("episode not found")
|
||||
}
|
||||
|
||||
var results []*models.VideoGeneration
|
||||
for _, storyboard := range episode.Storyboards {
|
||||
if storyboard.ImagePrompt == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
var imageGen models.ImageGeneration
|
||||
if err := s.db.Where("storyboard_id = ? AND status = ?", storyboard.ID, models.ImageStatusCompleted).
|
||||
Order("created_at DESC").First(&imageGen).Error; err != nil {
|
||||
s.log.Warnw("No completed image for storyboard", "storyboard_id", storyboard.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
videoGen, err := s.GenerateVideoFromImage(imageGen.ID)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to generate video", "storyboard_id", storyboard.ID, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
results = append(results, videoGen)
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *VideoGenerationService) DeleteVideoGeneration(id uint) error {
|
||||
return s.db.Delete(&models.VideoGeneration{}, id).Error
|
||||
}
|
||||
557
application/services/video_merge_service.go
Normal file
557
application/services/video_merge_service.go
Normal file
@@ -0,0 +1,557 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
models "github.com/drama-generator/backend/domain/models"
|
||||
"github.com/drama-generator/backend/infrastructure/external/ffmpeg"
|
||||
"github.com/drama-generator/backend/pkg/logger"
|
||||
"github.com/drama-generator/backend/pkg/video"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type VideoMergeService struct {
|
||||
db *gorm.DB
|
||||
aiService *AIService
|
||||
transferService *ResourceTransferService
|
||||
ffmpeg *ffmpeg.FFmpeg
|
||||
storagePath string
|
||||
baseURL string
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
func NewVideoMergeService(db *gorm.DB, transferService *ResourceTransferService, storagePath, baseURL string, log *logger.Logger) *VideoMergeService {
|
||||
return &VideoMergeService{
|
||||
db: db,
|
||||
aiService: NewAIService(db, log),
|
||||
transferService: transferService,
|
||||
ffmpeg: ffmpeg.NewFFmpeg(log),
|
||||
storagePath: storagePath,
|
||||
baseURL: baseURL,
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
type MergeVideoRequest struct {
|
||||
EpisodeID string `json:"episode_id" binding:"required"`
|
||||
DramaID string `json:"drama_id" binding:"required"`
|
||||
Title string `json:"title"`
|
||||
Scenes []models.SceneClip `json:"scenes" binding:"required,min=1"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) MergeVideos(req *MergeVideoRequest) (*models.VideoMerge, error) {
|
||||
// 验证episode权限
|
||||
var episode models.Episode
|
||||
if err := s.db.Preload("Drama").Where("id = ?", req.EpisodeID).First(&episode).Error; err != nil {
|
||||
return nil, fmt.Errorf("episode not found")
|
||||
}
|
||||
|
||||
// 验证所有场景都有视频
|
||||
for i, scene := range req.Scenes {
|
||||
if scene.VideoURL == "" {
|
||||
return nil, fmt.Errorf("scene %d has no video", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
provider := req.Provider
|
||||
if provider == "" {
|
||||
provider = "doubao"
|
||||
}
|
||||
|
||||
// 序列化场景列表
|
||||
scenesJSON, err := json.Marshal(req.Scenes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to serialize scenes: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infow("Serialized scenes to JSON",
|
||||
"scenes_count", len(req.Scenes),
|
||||
"scenes_json", string(scenesJSON))
|
||||
|
||||
epID, _ := strconv.ParseUint(req.EpisodeID, 10, 32)
|
||||
dramaID, _ := strconv.ParseUint(req.DramaID, 10, 32)
|
||||
|
||||
videoMerge := &models.VideoMerge{
|
||||
EpisodeID: uint(epID),
|
||||
DramaID: uint(dramaID),
|
||||
Title: req.Title,
|
||||
Provider: provider,
|
||||
Model: &req.Model,
|
||||
Scenes: scenesJSON,
|
||||
Status: models.VideoMergeStatusPending,
|
||||
}
|
||||
|
||||
if err := s.db.Create(videoMerge).Error; err != nil {
|
||||
return nil, fmt.Errorf("failed to create merge record: %w", err)
|
||||
}
|
||||
|
||||
go s.processMergeVideo(videoMerge.ID)
|
||||
|
||||
return videoMerge, nil
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) processMergeVideo(mergeID uint) {
|
||||
var videoMerge models.VideoMerge
|
||||
if err := s.db.First(&videoMerge, mergeID).Error; err != nil {
|
||||
s.log.Errorw("Failed to load video merge", "error", err, "id", mergeID)
|
||||
return
|
||||
}
|
||||
|
||||
s.db.Model(&videoMerge).Update("status", models.VideoMergeStatusProcessing)
|
||||
|
||||
client, err := s.getVideoClient(videoMerge.Provider)
|
||||
if err != nil {
|
||||
s.updateMergeError(mergeID, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 解析场景列表
|
||||
var scenes []models.SceneClip
|
||||
if err := json.Unmarshal(videoMerge.Scenes, &scenes); err != nil {
|
||||
s.updateMergeError(mergeID, fmt.Sprintf("failed to parse scenes: %v", err))
|
||||
return
|
||||
}
|
||||
|
||||
// 调用视频合并API
|
||||
result, err := s.mergeVideoClips(client, scenes)
|
||||
if err != nil {
|
||||
s.updateMergeError(mergeID, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if !result.Completed {
|
||||
s.db.Model(&videoMerge).Updates(map[string]interface{}{
|
||||
"status": models.VideoMergeStatusProcessing,
|
||||
"task_id": result.TaskID,
|
||||
})
|
||||
go s.pollMergeStatus(mergeID, client, result.TaskID)
|
||||
return
|
||||
}
|
||||
|
||||
s.completeMerge(mergeID, result)
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) mergeVideoClips(client video.VideoClient, scenes []models.SceneClip) (*video.VideoResult, error) {
|
||||
if len(scenes) == 0 {
|
||||
return nil, fmt.Errorf("no scenes to merge")
|
||||
}
|
||||
|
||||
// 按Order字段排序场景
|
||||
sort.Slice(scenes, func(i, j int) bool {
|
||||
return scenes[i].Order < scenes[j].Order
|
||||
})
|
||||
|
||||
s.log.Infow("Merging video clips with FFmpeg", "scene_count", len(scenes))
|
||||
|
||||
// 计算总时长
|
||||
var totalDuration float64
|
||||
for _, scene := range scenes {
|
||||
totalDuration += scene.Duration
|
||||
}
|
||||
|
||||
// 准备FFmpeg合成选项
|
||||
clips := make([]ffmpeg.VideoClip, len(scenes))
|
||||
for i, scene := range scenes {
|
||||
clips[i] = ffmpeg.VideoClip{
|
||||
URL: scene.VideoURL,
|
||||
Duration: scene.Duration,
|
||||
StartTime: scene.StartTime,
|
||||
EndTime: scene.EndTime,
|
||||
Transition: scene.Transition,
|
||||
}
|
||||
|
||||
s.log.Infow("Clip added to merge queue",
|
||||
"order", scene.Order,
|
||||
"index", i,
|
||||
"duration", scene.Duration,
|
||||
"start_time", scene.StartTime,
|
||||
"end_time", scene.EndTime)
|
||||
}
|
||||
|
||||
// 创建视频输出目录
|
||||
videoDir := filepath.Join(s.storagePath, "videos", "merged")
|
||||
if err := os.MkdirAll(videoDir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create video directory: %w", err)
|
||||
}
|
||||
|
||||
// 生成输出文件名
|
||||
fileName := fmt.Sprintf("merged_%d.mp4", time.Now().Unix())
|
||||
outputPath := filepath.Join(videoDir, fileName)
|
||||
|
||||
// 使用FFmpeg合成视频
|
||||
mergedPath, err := s.ffmpeg.MergeVideos(&ffmpeg.MergeOptions{
|
||||
OutputPath: outputPath,
|
||||
Clips: clips,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ffmpeg merge failed: %w", err)
|
||||
}
|
||||
|
||||
s.log.Infow("Video merged successfully", "path", mergedPath)
|
||||
|
||||
// 生成访问URL(相对路径)
|
||||
relPath := filepath.Join("videos", "merged", fileName)
|
||||
videoURL := fmt.Sprintf("%s/%s", s.baseURL, relPath)
|
||||
|
||||
result := &video.VideoResult{
|
||||
VideoURL: videoURL, // 返回可访问的URL
|
||||
Duration: int(totalDuration),
|
||||
Completed: true,
|
||||
Status: "completed",
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) pollMergeStatus(mergeID uint, client video.VideoClient, taskID string) {
|
||||
maxAttempts := 240
|
||||
pollInterval := 5 * time.Second
|
||||
|
||||
for i := 0; i < maxAttempts; i++ {
|
||||
time.Sleep(pollInterval)
|
||||
|
||||
result, err := client.GetTaskStatus(taskID)
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to get merge task status", "error", err, "task_id", taskID)
|
||||
continue
|
||||
}
|
||||
|
||||
if result.Completed {
|
||||
s.completeMerge(mergeID, result)
|
||||
return
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
s.updateMergeError(mergeID, result.Error)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.updateMergeError(mergeID, "timeout: video merge took too long")
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) completeMerge(mergeID uint, result *video.VideoResult) {
|
||||
now := time.Now()
|
||||
|
||||
// 获取merge记录
|
||||
var videoMerge models.VideoMerge
|
||||
if err := s.db.First(&videoMerge, mergeID).Error; err != nil {
|
||||
s.log.Errorw("Failed to load video merge for completion", "error", err, "id", mergeID)
|
||||
return
|
||||
}
|
||||
|
||||
finalVideoURL := result.VideoURL
|
||||
|
||||
// 使用本地存储,不再使用MinIO
|
||||
s.log.Infow("Video merge completed, using local storage", "merge_id", mergeID, "local_path", result.VideoURL)
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"status": models.VideoMergeStatusCompleted,
|
||||
"merged_url": finalVideoURL,
|
||||
"completed_at": now,
|
||||
}
|
||||
|
||||
if result.Duration > 0 {
|
||||
updates["duration"] = result.Duration
|
||||
}
|
||||
|
||||
s.db.Model(&models.VideoMerge{}).Where("id = ?", mergeID).Updates(updates)
|
||||
|
||||
// 更新episode的状态和最终视频URL
|
||||
if videoMerge.EpisodeID != 0 {
|
||||
s.db.Model(&models.Episode{}).Where("id = ?", videoMerge.EpisodeID).Updates(map[string]interface{}{
|
||||
"status": "completed",
|
||||
"video_url": finalVideoURL,
|
||||
})
|
||||
s.log.Infow("Episode finalized", "episode_id", videoMerge.EpisodeID, "video_url", finalVideoURL)
|
||||
}
|
||||
|
||||
s.log.Infow("Video merge completed", "id", mergeID, "url", finalVideoURL)
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) updateMergeError(mergeID uint, errorMsg string) {
|
||||
s.db.Model(&models.VideoMerge{}).Where("id = ?", mergeID).Updates(map[string]interface{}{
|
||||
"status": models.VideoMergeStatusFailed,
|
||||
"error_msg": errorMsg,
|
||||
})
|
||||
s.log.Errorw("Video merge failed", "id", mergeID, "error", errorMsg)
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient, error) {
|
||||
config, err := s.aiService.GetDefaultConfig("video")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get video config: %w", err)
|
||||
}
|
||||
|
||||
// 使用第一个模型
|
||||
model := ""
|
||||
if len(config.Model) > 0 {
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
// 根据配置中的 provider 创建对应的客户端
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch config.Provider {
|
||||
case "runway":
|
||||
return video.NewRunwayClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "pika":
|
||||
return video.NewPikaClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "openai", "sora":
|
||||
return video.NewOpenAISoraClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "minimax":
|
||||
return video.NewMinimaxClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "chatfire":
|
||||
endpoint = "/video/generations"
|
||||
queryEndpoint = "/video/task/{taskId}"
|
||||
return video.NewChatfireClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
case "doubao", "volces", "ark":
|
||||
endpoint = "/contents/generations/tasks"
|
||||
queryEndpoint = "/generations/tasks/{taskId}"
|
||||
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
default:
|
||||
endpoint = "/contents/generations/tasks"
|
||||
queryEndpoint = "/generations/tasks/{taskId}"
|
||||
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) GetMerge(mergeID uint) (*models.VideoMerge, error) {
|
||||
var merge models.VideoMerge
|
||||
if err := s.db.Where("id = ? ", mergeID).First(&merge).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &merge, nil
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) ListMerges(episodeID *string, status string, page, pageSize int) ([]models.VideoMerge, int64, error) {
|
||||
query := s.db.Model(&models.VideoMerge{})
|
||||
|
||||
if episodeID != nil && *episodeID != "" {
|
||||
query = query.Where("episode_id = ?", *episodeID)
|
||||
}
|
||||
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var merges []models.VideoMerge
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&merges).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return merges, total, nil
|
||||
}
|
||||
|
||||
func (s *VideoMergeService) DeleteMerge(mergeID uint) error {
|
||||
result := s.db.Where("id = ? ", mergeID).Delete(&models.VideoMerge{})
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("merge not found")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TimelineClip 时间线片段数据
|
||||
type TimelineClip struct {
|
||||
AssetID string `json:"asset_id"` // 素材库视频ID(优先使用)
|
||||
StoryboardID string `json:"storyboard_id"` // 分镜ID(fallback)
|
||||
Order int `json:"order"`
|
||||
StartTime float64 `json:"start_time"`
|
||||
EndTime float64 `json:"end_time"`
|
||||
Duration float64 `json:"duration"`
|
||||
Transition map[string]interface{} `json:"transition"`
|
||||
}
|
||||
|
||||
// FinalizeEpisodeRequest 完成剧集制作请求
|
||||
type FinalizeEpisodeRequest struct {
|
||||
EpisodeID string `json:"episode_id"`
|
||||
Clips []TimelineClip `json:"clips"`
|
||||
}
|
||||
|
||||
// FinalizeEpisode 完成集数制作,根据时间线场景顺序合成最终视频
|
||||
func (s *VideoMergeService) FinalizeEpisode(episodeID string, timelineData *FinalizeEpisodeRequest) (map[string]interface{}, error) {
|
||||
// 验证episode存在且属于该用户
|
||||
var episode models.Episode
|
||||
if err := s.db.Preload("Drama").Preload("Storyboards").Where("id = ?", episodeID).First(&episode).Error; err != nil {
|
||||
return nil, fmt.Errorf("episode not found")
|
||||
}
|
||||
|
||||
// 构建分镜ID映射
|
||||
sceneMap := make(map[string]models.Storyboard)
|
||||
for _, scene := range episode.Storyboards {
|
||||
sceneMap[fmt.Sprintf("%d", scene.ID)] = scene
|
||||
}
|
||||
|
||||
// 根据时间线数据构建场景片段
|
||||
var sceneClips []models.SceneClip
|
||||
var skippedScenes []int
|
||||
|
||||
if timelineData != nil && len(timelineData.Clips) > 0 {
|
||||
// 使用前端提供的时间线数据
|
||||
for _, clip := range timelineData.Clips {
|
||||
// 优先使用素材库中的视频(通过AssetID)
|
||||
var videoURL string
|
||||
var sceneID uint
|
||||
|
||||
if clip.AssetID != "" {
|
||||
// 从素材库获取视频URL
|
||||
var asset models.Asset
|
||||
if err := s.db.Where("id = ? AND type = ?", clip.AssetID, models.AssetTypeVideo).First(&asset).Error; err == nil {
|
||||
videoURL = asset.URL
|
||||
// 如果asset关联了storyboard,使用关联的storyboard_id
|
||||
if asset.StoryboardID != nil {
|
||||
sceneID = *asset.StoryboardID
|
||||
}
|
||||
s.log.Infow("Using video from asset library", "asset_id", clip.AssetID, "video_url", videoURL)
|
||||
} else {
|
||||
s.log.Warnw("Asset not found, will try storyboard video", "asset_id", clip.AssetID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有从素材库获取到视频,尝试从storyboard获取
|
||||
if videoURL == "" && clip.StoryboardID != "" {
|
||||
scene, exists := sceneMap[clip.StoryboardID]
|
||||
if !exists {
|
||||
s.log.Warnw("Storyboard not found in episode, skipping", "storyboard_id", clip.StoryboardID)
|
||||
continue
|
||||
}
|
||||
|
||||
if scene.VideoURL != nil && *scene.VideoURL != "" {
|
||||
videoURL = *scene.VideoURL
|
||||
sceneID = scene.ID
|
||||
s.log.Infow("Using video from storyboard", "storyboard_id", clip.StoryboardID, "video_url", videoURL)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果仍然没有视频URL,跳过该片段
|
||||
if videoURL == "" {
|
||||
s.log.Warnw("No video available for clip, skipping", "clip", clip)
|
||||
if clip.StoryboardID != "" {
|
||||
if scene, exists := sceneMap[clip.StoryboardID]; exists {
|
||||
skippedScenes = append(skippedScenes, scene.StoryboardNumber)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
sceneClip := models.SceneClip{
|
||||
SceneID: sceneID,
|
||||
VideoURL: videoURL,
|
||||
Duration: clip.Duration,
|
||||
Order: clip.Order,
|
||||
StartTime: clip.StartTime,
|
||||
EndTime: clip.EndTime,
|
||||
Transition: clip.Transition,
|
||||
}
|
||||
s.log.Infow("Adding scene clip with transition",
|
||||
"scene_id", sceneID,
|
||||
"order", clip.Order,
|
||||
"transition", clip.Transition)
|
||||
sceneClips = append(sceneClips, sceneClip)
|
||||
}
|
||||
} else {
|
||||
// 没有时间线数据,使用默认场景顺序
|
||||
if len(episode.Storyboards) == 0 {
|
||||
return nil, fmt.Errorf("no scenes found for this episode")
|
||||
}
|
||||
|
||||
order := 0
|
||||
for _, scene := range episode.Storyboards {
|
||||
// 优先从素材库查找该分镜关联的视频
|
||||
var videoURL string
|
||||
var asset models.Asset
|
||||
if err := s.db.Where("storyboard_id = ? AND type = ? AND episode_id = ?",
|
||||
scene.ID, models.AssetTypeVideo, episode.ID).
|
||||
Order("created_at DESC").
|
||||
First(&asset).Error; err == nil {
|
||||
videoURL = asset.URL
|
||||
s.log.Infow("Using video from asset library for storyboard",
|
||||
"storyboard_id", scene.ID,
|
||||
"asset_id", asset.ID,
|
||||
"video_url", videoURL)
|
||||
} else if scene.VideoURL != nil && *scene.VideoURL != "" {
|
||||
// 如果素材库没有,使用storyboard的video_url作为fallback
|
||||
videoURL = *scene.VideoURL
|
||||
s.log.Infow("Using fallback video from storyboard",
|
||||
"storyboard_id", scene.ID,
|
||||
"video_url", videoURL)
|
||||
}
|
||||
|
||||
// 跳过没有视频的场景
|
||||
if videoURL == "" {
|
||||
s.log.Warnw("Scene has no video, skipping", "storyboard_number", scene.StoryboardNumber)
|
||||
skippedScenes = append(skippedScenes, scene.StoryboardNumber)
|
||||
continue
|
||||
}
|
||||
|
||||
clip := models.SceneClip{
|
||||
SceneID: scene.ID,
|
||||
VideoURL: videoURL,
|
||||
Duration: float64(scene.Duration),
|
||||
Order: order,
|
||||
}
|
||||
sceneClips = append(sceneClips, clip)
|
||||
order++
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否至少有一个场景可以合成
|
||||
if len(sceneClips) == 0 {
|
||||
return nil, fmt.Errorf("no scenes with videos available for merging")
|
||||
}
|
||||
|
||||
// 创建视频合成任务
|
||||
title := fmt.Sprintf("%s - 第%d集", episode.Drama.Title, episode.EpisodeNum)
|
||||
|
||||
finalReq := &MergeVideoRequest{
|
||||
EpisodeID: episodeID,
|
||||
DramaID: fmt.Sprintf("%d", episode.DramaID),
|
||||
Title: title,
|
||||
Scenes: sceneClips,
|
||||
Provider: "doubao", // 默认使用doubao
|
||||
}
|
||||
|
||||
// 执行视频合成
|
||||
videoMerge, err := s.MergeVideos(finalReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to start video merge: %w", err)
|
||||
}
|
||||
|
||||
// 更新episode状态为processing
|
||||
s.db.Model(&episode).Updates(map[string]interface{}{
|
||||
"status": "processing",
|
||||
})
|
||||
|
||||
result := map[string]interface{}{
|
||||
"message": "视频合成任务已创建,正在后台处理",
|
||||
"merge_id": videoMerge.ID,
|
||||
"episode_id": episodeID,
|
||||
"scenes_count": len(sceneClips),
|
||||
}
|
||||
|
||||
// 如果有跳过的场景,添加提示信息
|
||||
if len(skippedScenes) > 0 {
|
||||
result["skipped_scenes"] = skippedScenes
|
||||
result["warning"] = fmt.Sprintf("已跳过 %d 个未生成视频的场景(场景编号:%v)", len(skippedScenes), skippedScenes)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
Reference in New Issue
Block a user