create project

This commit is contained in:
2026-01-16 17:30:40 +08:00
commit effac6b017
157 changed files with 45997 additions and 0 deletions

View File

@@ -0,0 +1,398 @@
package services
import (
"errors"
"fmt"
"github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/ai"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
type AIService struct {
db *gorm.DB
log *logger.Logger
}
func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
return &AIService{
db: db,
log: log,
}
}
type CreateAIConfigRequest struct {
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
Name string `json:"name" binding:"required,min=1,max=100"`
Provider string `json:"provider" binding:"required"`
BaseURL string `json:"base_url" binding:"required,url"`
APIKey string `json:"api_key" binding:"required"`
Model models.ModelField `json:"model" binding:"required"`
Endpoint string `json:"endpoint"`
QueryEndpoint string `json:"query_endpoint"`
Priority int `json:"priority"`
IsDefault bool `json:"is_default"`
Settings string `json:"settings"`
}
type UpdateAIConfigRequest struct {
Name string `json:"name" binding:"omitempty,min=1,max=100"`
Provider string `json:"provider"`
BaseURL string `json:"base_url" binding:"omitempty,url"`
APIKey string `json:"api_key"`
Model *models.ModelField `json:"model"`
Endpoint string `json:"endpoint"`
QueryEndpoint string `json:"query_endpoint"`
Priority *int `json:"priority"`
IsDefault bool `json:"is_default"`
IsActive bool `json:"is_active"`
Settings string `json:"settings"`
}
type TestConnectionRequest struct {
BaseURL string `json:"base_url" binding:"required,url"`
APIKey string `json:"api_key" binding:"required"`
Model models.ModelField `json:"model" binding:"required"`
Provider string `json:"provider"`
Endpoint string `json:"endpoint"`
}
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
// 根据 provider 和 service_type 自动设置 endpoint
endpoint := req.Endpoint
queryEndpoint := req.QueryEndpoint
if endpoint == "" {
switch req.Provider {
case "gemini", "google":
if req.ServiceType == "text" {
endpoint = "/v1beta/models/{model}:generateContent"
} else if req.ServiceType == "image" {
endpoint = "/v1beta/models/{model}:generateContent"
}
case "openai":
if req.ServiceType == "text" {
endpoint = "/chat/completions"
} else if req.ServiceType == "image" {
endpoint = "/images/generations"
} else if req.ServiceType == "video" {
endpoint = "/videos"
if queryEndpoint == "" {
queryEndpoint = "/videos/{taskId}"
}
}
case "chatfire":
if req.ServiceType == "text" {
endpoint = "/chat/completions"
} else if req.ServiceType == "image" {
endpoint = "/images/generations"
} else if req.ServiceType == "video" {
endpoint = "/video/generations"
if queryEndpoint == "" {
queryEndpoint = "/video/task/{taskId}"
}
}
case "doubao", "volcengine", "volces":
if req.ServiceType == "video" {
endpoint = "/contents/generations/tasks"
if queryEndpoint == "" {
queryEndpoint = "/generations/tasks/{taskId}"
}
}
default:
// 默认使用 OpenAI 格式
if req.ServiceType == "text" {
endpoint = "/chat/completions"
} else if req.ServiceType == "image" {
endpoint = "/images/generations"
}
}
}
config := &models.AIServiceConfig{
ServiceType: req.ServiceType,
Name: req.Name,
Provider: req.Provider,
BaseURL: req.BaseURL,
APIKey: req.APIKey,
Model: req.Model,
Endpoint: endpoint,
QueryEndpoint: queryEndpoint,
Priority: req.Priority,
IsDefault: req.IsDefault,
IsActive: true,
Settings: req.Settings,
}
if err := s.db.Create(config).Error; err != nil {
s.log.Errorw("Failed to create AI config", "error", err)
return nil, err
}
s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint)
return config, nil
}
func (s *AIService) GetConfig(configID uint) (*models.AIServiceConfig, error) {
var config models.AIServiceConfig
err := s.db.Where("id = ? ", configID).First(&config).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("config not found")
}
return nil, err
}
return &config, nil
}
func (s *AIService) ListConfigs(serviceType string) ([]models.AIServiceConfig, error) {
var configs []models.AIServiceConfig
query := s.db
if serviceType != "" {
query = query.Where("service_type = ?", serviceType)
}
err := query.Order("priority DESC, created_at DESC").Find(&configs).Error
if err != nil {
s.log.Errorw("Failed to list AI configs", "error", err)
return nil, err
}
return configs, nil
}
func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*models.AIServiceConfig, error) {
var config models.AIServiceConfig
if err := s.db.Where("id = ? ", configID).First(&config).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("config not found")
}
return nil, err
}
tx := s.db.Begin()
// 不再需要is_default独占逻辑
updates := make(map[string]interface{})
if req.Name != "" {
updates["name"] = req.Name
}
if req.Provider != "" {
updates["provider"] = req.Provider
}
if req.BaseURL != "" {
updates["base_url"] = req.BaseURL
}
if req.APIKey != "" {
updates["api_key"] = req.APIKey
}
if req.Model != nil && len(*req.Model) > 0 {
updates["model"] = *req.Model
}
if req.Priority != nil {
updates["priority"] = *req.Priority
}
// 如果提供了 provider根据 provider 和 service_type 自动设置 endpoint
if req.Provider != "" && req.Endpoint == "" {
provider := req.Provider
serviceType := config.ServiceType
switch provider {
case "gemini", "google":
if serviceType == "text" || serviceType == "image" {
updates["endpoint"] = "/v1beta/models/{model}:generateContent"
}
case "openai":
if serviceType == "text" {
updates["endpoint"] = "/chat/completions"
} else if serviceType == "image" {
updates["endpoint"] = "/images/generations"
} else if serviceType == "video" {
updates["endpoint"] = "/videos"
updates["query_endpoint"] = "/videos/{taskId}"
}
case "chatfire":
if serviceType == "text" {
updates["endpoint"] = "/chat/completions"
} else if serviceType == "image" {
updates["endpoint"] = "/images/generations"
} else if serviceType == "video" {
updates["endpoint"] = "/video/generations"
updates["query_endpoint"] = "/video/task/{taskId}"
}
}
} else if req.Endpoint != "" {
updates["endpoint"] = req.Endpoint
}
// 允许清空query_endpoint所以不检查是否为空
updates["query_endpoint"] = req.QueryEndpoint
if req.Settings != "" {
updates["settings"] = req.Settings
}
updates["is_default"] = req.IsDefault
updates["is_active"] = req.IsActive
if err := tx.Model(&config).Updates(updates).Error; err != nil {
tx.Rollback()
s.log.Errorw("Failed to update AI config", "error", err)
return nil, err
}
if err := tx.Commit().Error; err != nil {
return nil, err
}
s.log.Infow("AI config updated", "config_id", configID)
return &config, nil
}
func (s *AIService) DeleteConfig(configID uint) error {
result := s.db.Where("id = ? ", configID).Delete(&models.AIServiceConfig{})
if result.Error != nil {
s.log.Errorw("Failed to delete AI config", "error", result.Error)
return result.Error
}
if result.RowsAffected == 0 {
return errors.New("config not found")
}
s.log.Infow("AI config deleted", "config_id", configID)
return nil
}
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model))
// 使用第一个模型进行测试
model := ""
if len(req.Model) > 0 {
model = req.Model[0]
}
s.log.Infow("Using model for test", "model", model, "provider", req.Provider)
// 根据 provider 参数选择客户端
var client ai.AIClient
var endpoint string
switch req.Provider {
case "gemini", "google":
// Gemini
s.log.Infow("Using Gemini client", "baseURL", req.BaseURL)
endpoint = "/v1beta/models/{model}:generateContent"
client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint)
case "openai", "chatfire":
// OpenAI 格式(包括 chatfire 等)
s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider)
endpoint = req.Endpoint
if endpoint == "" {
endpoint = "/chat/completions"
}
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
default:
// 默认使用 OpenAI 格式
s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL)
endpoint = req.Endpoint
if endpoint == "" {
endpoint = "/chat/completions"
}
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
}
s.log.Infow("Calling TestConnection on client", "endpoint", endpoint)
err := client.TestConnection()
if err != nil {
s.log.Errorw("TestConnection failed", "error", err)
} else {
s.log.Infow("TestConnection succeeded")
}
return err
}
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
var config models.AIServiceConfig
// 按优先级降序获取第一个配置
err := s.db.Where("service_type = ?", serviceType).
Order("priority DESC, created_at DESC").
First(&config).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("no config found")
}
return nil, err
}
return &config, nil
}
// GetConfigForModel 根据服务类型和模型名称获取优先级最高的配置
func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*models.AIServiceConfig, error) {
var configs []models.AIServiceConfig
err := s.db.Where("service_type = ?", serviceType).
Order("priority DESC, created_at DESC").
Find(&configs).Error
if err != nil {
return nil, err
}
// 查找包含指定模型的配置
for _, config := range configs {
for _, model := range config.Model {
if model == modelName {
return &config, nil
}
}
}
return nil, errors.New("no config found for model: " + modelName)
}
func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) {
config, err := s.GetDefaultConfig(serviceType)
if err != nil {
return nil, err
}
// 使用第一个模型
model := ""
if len(config.Model) > 0 {
model = config.Model[0]
}
// 使用数据库配置中的 endpoint如果为空则根据 provider 设置默认值
endpoint := config.Endpoint
if endpoint == "" {
switch config.Provider {
case "gemini", "google":
endpoint = "/v1beta/models/{model}:generateContent"
default:
endpoint = "/chat/completions"
}
}
// 根据 provider 创建对应的客户端
switch config.Provider {
case "gemini", "google":
return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil
default:
// openai, chatfire 等其他厂商都使用 OpenAI 格式
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil
}
}
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {
client, err := s.GetAIClient("text")
if err != nil {
return "", fmt.Errorf("failed to get AI client: %w", err)
}
return client.GenerateText(prompt, systemPrompt, options...)
}

View File

@@ -0,0 +1,287 @@
package services
import (
"fmt"
"strconv"
"strings"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
type AssetService struct {
db *gorm.DB
log *logger.Logger
}
func NewAssetService(db *gorm.DB, log *logger.Logger) *AssetService {
return &AssetService{
db: db,
log: log,
}
}
type CreateAssetRequest struct {
DramaID *string `json:"drama_id"`
Name string `json:"name" binding:"required"`
Description *string `json:"description"`
Type models.AssetType `json:"type" binding:"required"`
Category *string `json:"category"`
URL string `json:"url" binding:"required"`
ThumbnailURL *string `json:"thumbnail_url"`
LocalPath *string `json:"local_path"`
FileSize *int64 `json:"file_size"`
MimeType *string `json:"mime_type"`
Width *int `json:"width"`
Height *int `json:"height"`
Duration *int `json:"duration"`
Format *string `json:"format"`
ImageGenID *uint `json:"image_gen_id"`
VideoGenID *uint `json:"video_gen_id"`
TagIDs []uint `json:"tag_ids"`
}
type UpdateAssetRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
Category *string `json:"category"`
ThumbnailURL *string `json:"thumbnail_url"`
TagIDs []uint `json:"tag_ids"`
IsFavorite *bool `json:"is_favorite"`
}
type ListAssetsRequest struct {
DramaID *string `json:"drama_id"`
EpisodeID *uint `json:"episode_id"`
StoryboardID *uint `json:"storyboard_id"`
Type *models.AssetType `json:"type"`
Category string `json:"category"`
TagIDs []uint `json:"tag_ids"`
IsFavorite *bool `json:"is_favorite"`
Search string `json:"search"`
Page int `json:"page"`
PageSize int `json:"page_size"`
}
func (s *AssetService) CreateAsset(req *CreateAssetRequest) (*models.Asset, error) {
var dramaID *uint
if req.DramaID != nil && *req.DramaID != "" {
id, err := strconv.ParseUint(*req.DramaID, 10, 32)
if err == nil {
uid := uint(id)
dramaID = &uid
}
}
if dramaID != nil {
var drama models.Drama
if err := s.db.Where("id = ?", *dramaID).First(&drama).Error; err != nil {
return nil, fmt.Errorf("drama not found")
}
}
asset := &models.Asset{
DramaID: dramaID,
Name: req.Name,
Description: req.Description,
Type: req.Type,
Category: req.Category,
URL: req.URL,
ThumbnailURL: req.ThumbnailURL,
LocalPath: req.LocalPath,
FileSize: req.FileSize,
MimeType: req.MimeType,
Width: req.Width,
Height: req.Height,
Duration: req.Duration,
Format: req.Format,
ImageGenID: req.ImageGenID,
VideoGenID: req.VideoGenID,
}
if err := s.db.Create(asset).Error; err != nil {
return nil, fmt.Errorf("failed to create asset: %w", err)
}
return asset, nil
}
func (s *AssetService) UpdateAsset(assetID uint, req *UpdateAssetRequest) (*models.Asset, error) {
var asset models.Asset
if err := s.db.Where("id = ?", assetID).First(&asset).Error; err != nil {
return nil, fmt.Errorf("asset not found")
}
updates := make(map[string]interface{})
if req.Name != nil {
updates["name"] = *req.Name
}
if req.Description != nil {
updates["description"] = *req.Description
}
if req.Category != nil {
updates["category"] = *req.Category
}
if req.ThumbnailURL != nil {
updates["thumbnail_url"] = *req.ThumbnailURL
}
if req.IsFavorite != nil {
updates["is_favorite"] = *req.IsFavorite
}
if len(updates) > 0 {
if err := s.db.Model(&asset).Updates(updates).Error; err != nil {
return nil, fmt.Errorf("failed to update asset: %w", err)
}
}
if err := s.db.First(&asset, assetID).Error; err != nil {
return nil, err
}
return &asset, nil
}
func (s *AssetService) GetAsset(assetID uint) (*models.Asset, error) {
var asset models.Asset
if err := s.db.Where("id = ? ", assetID).First(&asset).Error; err != nil {
return nil, err
}
s.db.Model(&asset).UpdateColumn("view_count", gorm.Expr("view_count + ?", 1))
return &asset, nil
}
func (s *AssetService) ListAssets(req *ListAssetsRequest) ([]models.Asset, int64, error) {
query := s.db.Model(&models.Asset{})
if req.DramaID != nil {
var dramaID uint64
dramaID, _ = strconv.ParseUint(*req.DramaID, 10, 32)
query = query.Where("drama_id = ?", uint(dramaID))
}
if req.EpisodeID != nil {
query = query.Where("episode_id = ?", *req.EpisodeID)
}
if req.StoryboardID != nil {
query = query.Where("storyboard_id = ?", *req.StoryboardID)
}
if req.Type != nil {
query = query.Where("type = ?", *req.Type)
}
if req.Category != "" {
query = query.Where("category = ?", req.Category)
}
if req.IsFavorite != nil {
query = query.Where("is_favorite = ?", *req.IsFavorite)
}
if req.Search != "" {
searchTerm := "%" + strings.ToLower(req.Search) + "%"
query = query.Where("LOWER(name) LIKE ? OR LOWER(description) LIKE ?", searchTerm, searchTerm)
}
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
var assets []models.Asset
offset := (req.Page - 1) * req.PageSize
if err := query.Order("created_at DESC").
Offset(offset).Limit(req.PageSize).Find(&assets).Error; err != nil {
return nil, 0, err
}
return assets, total, nil
}
func (s *AssetService) DeleteAsset(assetID uint) error {
result := s.db.Where("id = ?", assetID).Delete(&models.Asset{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("asset not found")
}
return nil
}
func (s *AssetService) ImportFromImageGen(imageGenID uint) (*models.Asset, error) {
var imageGen models.ImageGeneration
if err := s.db.Where("id = ? ", imageGenID).First(&imageGen).Error; err != nil {
return nil, fmt.Errorf("image generation not found")
}
if imageGen.Status != models.ImageStatusCompleted || imageGen.ImageURL == nil {
return nil, fmt.Errorf("image is not ready")
}
dramaID := imageGen.DramaID
asset := &models.Asset{
Name: fmt.Sprintf("Image_%d", imageGen.ID),
Type: models.AssetTypeImage,
URL: *imageGen.ImageURL,
DramaID: &dramaID,
ImageGenID: &imageGenID,
Width: imageGen.Width,
Height: imageGen.Height,
}
if err := s.db.Create(asset).Error; err != nil {
return nil, fmt.Errorf("failed to create asset: %w", err)
}
return asset, nil
}
func (s *AssetService) ImportFromVideoGen(videoGenID uint) (*models.Asset, error) {
var videoGen models.VideoGeneration
if err := s.db.Preload("Storyboard.Episode").Where("id = ? ", videoGenID).First(&videoGen).Error; err != nil {
return nil, fmt.Errorf("video generation not found")
}
if videoGen.Status != models.VideoStatusCompleted || videoGen.VideoURL == nil {
return nil, fmt.Errorf("video is not ready")
}
dramaID := videoGen.DramaID
var episodeID *uint
var storyboardNum *int
if videoGen.Storyboard != nil {
episodeID = &videoGen.Storyboard.Episode.ID
storyboardNum = &videoGen.Storyboard.StoryboardNumber
}
asset := &models.Asset{
Name: fmt.Sprintf("Video_%d", videoGen.ID),
Type: models.AssetTypeVideo,
URL: *videoGen.VideoURL,
DramaID: &dramaID,
EpisodeID: episodeID,
StoryboardID: videoGen.StoryboardID,
StoryboardNum: storyboardNum,
VideoGenID: &videoGenID,
Duration: videoGen.Duration,
Width: videoGen.Width,
Height: videoGen.Height,
}
if videoGen.FirstFrameURL != nil {
asset.ThumbnailURL = videoGen.FirstFrameURL
}
if err := s.db.Create(asset).Error; err != nil {
return nil, fmt.Errorf("failed to create asset: %w", err)
}
return asset, nil
}

View File

@@ -0,0 +1,473 @@
package services
import (
"errors"
"fmt"
"time"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
type CharacterLibraryService struct {
db *gorm.DB
log *logger.Logger
}
func NewCharacterLibraryService(db *gorm.DB, log *logger.Logger) *CharacterLibraryService {
return &CharacterLibraryService{
db: db,
log: log,
}
}
type CreateLibraryItemRequest struct {
Name string `json:"name" binding:"required,min=1,max=100"`
Category *string `json:"category"`
ImageURL string `json:"image_url" binding:"required"`
Description *string `json:"description"`
Tags *string `json:"tags"`
SourceType string `json:"source_type"`
}
type CharacterLibraryQuery struct {
Page int `form:"page,default=1"`
PageSize int `form:"page_size,default=20"`
Category string `form:"category"`
SourceType string `form:"source_type"`
Keyword string `form:"keyword"`
}
// ListLibraryItems 获取用户角色库列表
func (s *CharacterLibraryService) ListLibraryItems(query *CharacterLibraryQuery) ([]models.CharacterLibrary, int64, error) {
var items []models.CharacterLibrary
var total int64
db := s.db.Model(&models.CharacterLibrary{})
// 筛选条件
if query.Category != "" {
db = db.Where("category = ?", query.Category)
}
if query.SourceType != "" {
db = db.Where("source_type = ?", query.SourceType)
}
if query.Keyword != "" {
db = db.Where("name LIKE ? OR description LIKE ?", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
}
// 获取总数
if err := db.Count(&total).Error; err != nil {
s.log.Errorw("Failed to count character library", "error", err)
return nil, 0, err
}
// 分页查询
offset := (query.Page - 1) * query.PageSize
err := db.Order("created_at DESC").
Offset(offset).
Limit(query.PageSize).
Find(&items).Error
if err != nil {
s.log.Errorw("Failed to list character library", "error", err)
return nil, 0, err
}
return items, total, nil
}
// CreateLibraryItem 添加到角色库
func (s *CharacterLibraryService) CreateLibraryItem(req *CreateLibraryItemRequest) (*models.CharacterLibrary, error) {
sourceType := req.SourceType
if sourceType == "" {
sourceType = "generated"
}
item := &models.CharacterLibrary{
Name: req.Name,
Category: req.Category,
ImageURL: req.ImageURL,
Description: req.Description,
Tags: req.Tags,
SourceType: sourceType,
}
if err := s.db.Create(item).Error; err != nil {
s.log.Errorw("Failed to create library item", "error", err)
return nil, err
}
s.log.Infow("Library item created", "item_id", item.ID)
return item, nil
}
// GetLibraryItem 获取角色库项
func (s *CharacterLibraryService) GetLibraryItem(itemID string) (*models.CharacterLibrary, error) {
var item models.CharacterLibrary
err := s.db.Where("id = ? ", itemID).First(&item).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("library item not found")
}
s.log.Errorw("Failed to get library item", "error", err)
return nil, err
}
return &item, nil
}
// DeleteLibraryItem 删除角色库项
func (s *CharacterLibraryService) DeleteLibraryItem(itemID string) error {
result := s.db.Where("id = ? ", itemID).Delete(&models.CharacterLibrary{})
if result.Error != nil {
s.log.Errorw("Failed to delete library item", "error", result.Error)
return result.Error
}
if result.RowsAffected == 0 {
return errors.New("library item not found")
}
s.log.Infow("Library item deleted", "item_id", itemID)
return nil
}
// ApplyLibraryItemToCharacter 将角色库形象应用到角色
func (s *CharacterLibraryService) ApplyLibraryItemToCharacter(characterID string, libraryItemID string) error {
// 验证角色库项存在且属于该用户
var libraryItem models.CharacterLibrary
if err := s.db.Where("id = ? ", libraryItemID).First(&libraryItem).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("library item not found")
}
return err
}
// 查找角色
var character models.Character
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("character not found")
}
return err
}
// 查询Drama验证权限
var drama models.Drama
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("unauthorized")
}
return err
}
// 更新角色的image_url
if err := s.db.Model(&character).Update("image_url", libraryItem.ImageURL).Error; err != nil {
s.log.Errorw("Failed to update character image", "error", err)
return err
}
s.log.Infow("Library item applied to character", "character_id", characterID, "library_item_id", libraryItemID)
return nil
}
// UploadCharacterImage 上传角色图片
func (s *CharacterLibraryService) UploadCharacterImage(characterID string, imageURL string) error {
// 查找角色
var character models.Character
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("character not found")
}
return err
}
// 查询Drama验证权限
var drama models.Drama
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("unauthorized")
}
return err
}
// 更新图片URL
if err := s.db.Model(&character).Update("image_url", imageURL).Error; err != nil {
s.log.Errorw("Failed to update character image", "error", err)
return err
}
s.log.Infow("Character image uploaded", "character_id", characterID)
return nil
}
// AddCharacterToLibrary 将角色添加到角色库
func (s *CharacterLibraryService) AddCharacterToLibrary(characterID string, category *string) (*models.CharacterLibrary, error) {
// 查找角色
var character models.Character
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("character not found")
}
return nil, err
}
// 查询Drama验证权限
var drama models.Drama
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("unauthorized")
}
return nil, err
}
// 检查是否有图片
if character.ImageURL == nil || *character.ImageURL == "" {
return nil, fmt.Errorf("角色还没有形象图片")
}
// 创建角色库项
charLibrary := &models.CharacterLibrary{
Name: character.Name,
ImageURL: *character.ImageURL,
Description: character.Description,
SourceType: "character",
}
if err := s.db.Create(charLibrary).Error; err != nil {
s.log.Errorw("Failed to add character to library", "error", err)
return nil, err
}
s.log.Infow("Character added to library", "character_id", characterID, "library_item_id", charLibrary.ID)
return charLibrary, nil
}
// DeleteCharacter 删除单个角色
func (s *CharacterLibraryService) DeleteCharacter(characterID uint) error {
// 查找角色
var character models.Character
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("character not found")
}
return err
}
// 验证权限检查角色所属的drama是否属于当前用户
var drama models.Drama
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("unauthorized")
}
return err
}
// 删除角色
if err := s.db.Delete(&character).Error; err != nil {
s.log.Errorw("Failed to delete character", "error", err, "id", characterID)
return err
}
s.log.Infow("Character deleted", "id", characterID)
return nil
}
// GenerateCharacterImage AI生成角色形象
func (s *CharacterLibraryService) GenerateCharacterImage(characterID string, imageService *ImageGenerationService, modelName string) (*models.ImageGeneration, error) {
// 查找角色
var character models.Character
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("character not found")
}
return nil, err
}
// 查询Drama验证权限
var drama models.Drama
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("unauthorized")
}
return nil, err
}
// 构建生成提示词 - 使用详细的外貌描述,添加干净背景要求
prompt := ""
// 优先使用appearance字段它包含了最详细的外貌描述
if character.Appearance != nil && *character.Appearance != "" {
prompt = *character.Appearance
} else if character.Description != nil && *character.Description != "" {
prompt = *character.Description
} else {
prompt = character.Name
}
// 添加角色画像和风格要求
prompt += ", character portrait, full body or upper body shot"
// 添加干净背景要求 - 确保背景简洁不干扰主体
prompt += ", simple clean background, plain solid color background, white or light gray background"
prompt += ", studio lighting, professional photography"
// 添加质量和风格要求
prompt += ", high quality, detailed, anime style, character design"
prompt += ", no complex background, no scenery, focus on character"
// 调用图片生成服务
dramaIDStr := fmt.Sprintf("%d", character.DramaID)
imageType := "character"
req := &GenerateImageRequest{
DramaID: dramaIDStr,
CharacterID: &character.ID,
ImageType: imageType,
Prompt: prompt,
Provider: "openai", // 或从配置读取
Model: modelName, // 使用用户指定的模型
Size: "2560x1440", // 3,686,400像素满足API最低要求16:9比例
Quality: "standard",
}
imageGen, err := imageService.GenerateImage(req)
if err != nil {
s.log.Errorw("Failed to generate character image", "error", err)
return nil, fmt.Errorf("图片生成失败: %w", err)
}
// 异步处理在后台监听图片生成完成然后更新角色image_url
go s.waitAndUpdateCharacterImage(character.ID, imageGen.ID)
// 立即返回ImageGeneration对象让前端可以轮询状态
s.log.Infow("Character image generation started", "character_id", characterID, "image_gen_id", imageGen.ID)
return imageGen, nil
}
// waitAndUpdateCharacterImage 后台异步等待图片生成完成并更新角色image_url
func (s *CharacterLibraryService) waitAndUpdateCharacterImage(characterID uint, imageGenID uint) {
maxAttempts := 60
pollInterval := 5 * time.Second
for i := 0; i < maxAttempts; i++ {
time.Sleep(pollInterval)
// 查询图片生成状态
var imageGen models.ImageGeneration
if err := s.db.First(&imageGen, imageGenID).Error; err != nil {
s.log.Errorw("Failed to query image generation status", "error", err, "image_gen_id", imageGenID)
continue
}
// 检查是否完成
if imageGen.Status == models.ImageStatusCompleted && imageGen.ImageURL != nil && *imageGen.ImageURL != "" {
// 更新角色的image_url
if err := s.db.Model(&models.Character{}).Where("id = ?", characterID).Update("image_url", *imageGen.ImageURL).Error; err != nil {
s.log.Errorw("Failed to update character image_url", "error", err, "character_id", characterID)
return
}
s.log.Infow("Character image updated successfully", "character_id", characterID, "image_url", *imageGen.ImageURL)
return
}
// 检查是否失败
if imageGen.Status == models.ImageStatusFailed {
s.log.Errorw("Character image generation failed", "character_id", characterID, "image_gen_id", imageGenID, "error", imageGen.ErrorMsg)
return
}
}
s.log.Warnw("Character image generation timeout", "character_id", characterID, "image_gen_id", imageGenID)
}
// UpdateCharacter 更新角色信息
func (s *CharacterLibraryService) UpdateCharacter(characterID string, req interface{}) error {
// 查找角色
var character models.Character
if err := s.db.Where("id = ?", characterID).First(&character).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("character not found")
}
return err
}
// 验证权限查询角色所属的drama是否属于该用户
var drama models.Drama
if err := s.db.Where("id = ? ", character.DramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("unauthorized")
}
return err
}
// 构建更新数据
updates := make(map[string]interface{})
// 使用类型断言获取请求数据
if reqMap, ok := req.(*struct {
Name *string `json:"name"`
Appearance *string `json:"appearance"`
Personality *string `json:"personality"`
Description *string `json:"description"`
}); ok {
if reqMap.Name != nil && *reqMap.Name != "" {
updates["name"] = *reqMap.Name
}
if reqMap.Appearance != nil {
updates["appearance"] = *reqMap.Appearance
}
if reqMap.Personality != nil {
updates["personality"] = *reqMap.Personality
}
if reqMap.Description != nil {
updates["description"] = *reqMap.Description
}
}
if len(updates) == 0 {
return errors.New("no fields to update")
}
// 更新角色信息
if err := s.db.Model(&character).Updates(updates).Error; err != nil {
s.log.Errorw("Failed to update character", "error", err, "character_id", characterID)
return err
}
s.log.Infow("Character updated", "character_id", characterID)
return nil
}
// BatchGenerateCharacterImages 批量生成角色图片(并发执行)
func (s *CharacterLibraryService) BatchGenerateCharacterImages(characterIDs []string, imageService *ImageGenerationService, modelName string) {
s.log.Infow("Starting batch character image generation",
"count", len(characterIDs),
"model", modelName)
// 使用 goroutine 并发生成所有角色图片
for _, characterID := range characterIDs {
// 为每个角色启动单独的 goroutine
go func(charID string) {
imageGen, err := s.GenerateCharacterImage(charID, imageService, modelName)
if err != nil {
s.log.Errorw("Failed to generate character image in batch",
"character_id", charID,
"error", err)
return
}
s.log.Infow("Character image generated in batch",
"character_id", charID,
"image_gen_id", imageGen.ID)
}(characterID)
}
s.log.Infow("Batch character image generation tasks submitted",
"total", len(characterIDs))
}

View File

@@ -0,0 +1,630 @@
package services
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
type DramaService struct {
db *gorm.DB
log *logger.Logger
}
func NewDramaService(db *gorm.DB, log *logger.Logger) *DramaService {
return &DramaService{
db: db,
log: log,
}
}
type CreateDramaRequest struct {
Title string `json:"title" binding:"required,min=1,max=100"`
Description string `json:"description"`
Genre string `json:"genre"`
Tags string `json:"tags"`
}
type UpdateDramaRequest struct {
Title string `json:"title" binding:"omitempty,min=1,max=100"`
Description string `json:"description"`
Genre string `json:"genre"`
Tags string `json:"tags"`
Status string `json:"status" binding:"omitempty,oneof=draft planning production completed archived"`
}
type DramaListQuery struct {
Page int `form:"page,default=1"`
PageSize int `form:"page_size,default=20"`
Status string `form:"status"`
Genre string `form:"genre"`
Keyword string `form:"keyword"`
}
func (s *DramaService) CreateDrama(req *CreateDramaRequest) (*models.Drama, error) {
drama := &models.Drama{
Title: req.Title,
Status: "draft",
}
if req.Description != "" {
drama.Description = &req.Description
}
if req.Genre != "" {
drama.Genre = &req.Genre
}
if err := s.db.Create(drama).Error; err != nil {
s.log.Errorw("Failed to create drama", "error", err)
return nil, err
}
s.log.Infow("Drama created", "drama_id", drama.ID)
return drama, nil
}
func (s *DramaService) GetDrama(dramaID string) (*models.Drama, error) {
var drama models.Drama
err := s.db.Where("id = ? ", dramaID).
Preload("Characters"). // 加载Drama级别的角色
Preload("Scenes"). // 加载Drama级别的场景
Preload("Episodes.Characters"). // 加载每个章节关联的角色
Preload("Episodes.Scenes"). // 加载每个章节关联的场景
Preload("Episodes.Storyboards", func(db *gorm.DB) *gorm.DB {
return db.Order("storyboards.storyboard_number ASC")
}).
First(&drama).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("drama not found")
}
s.log.Errorw("Failed to get drama", "error", err)
return nil, err
}
// 统计每个剧集的时长(基于场景时长之和)
for i := range drama.Episodes {
totalDuration := 0
for _, scene := range drama.Episodes[i].Storyboards {
totalDuration += scene.Duration
}
// 更新剧集时长(秒转分钟,向上取整)
durationMinutes := (totalDuration + 59) / 60
drama.Episodes[i].Duration = durationMinutes
// 如果数据库中的时长与计算的不一致,更新数据库
if drama.Episodes[i].Duration != durationMinutes {
s.db.Model(&models.Episode{}).Where("id = ?", drama.Episodes[i].ID).Update("duration", durationMinutes)
}
// 查询角色的图片生成状态
for j := range drama.Episodes[i].Characters {
var imageGen models.ImageGeneration
err := s.db.Where("character_id = ? AND (status = ? OR status = ?)",
drama.Episodes[i].Characters[j].ID, "pending", "processing").
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
// 找到生成中的记录,设置状态
statusStr := string(imageGen.Status)
drama.Episodes[i].Characters[j].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
drama.Episodes[i].Characters[j].ImageGenerationError = imageGen.ErrorMsg
}
} else if errors.Is(err, gorm.ErrRecordNotFound) {
// 检查是否有失败的记录
err := s.db.Where("character_id = ? AND status = ?",
drama.Episodes[i].Characters[j].ID, "failed").
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
statusStr := string(imageGen.Status)
drama.Episodes[i].Characters[j].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
drama.Episodes[i].Characters[j].ImageGenerationError = imageGen.ErrorMsg
}
}
}
}
// 查询场景的图片生成状态
for j := range drama.Episodes[i].Scenes {
var imageGen models.ImageGeneration
err := s.db.Where("scene_id = ? AND (status = ? OR status = ?)",
drama.Episodes[i].Scenes[j].ID, "pending", "processing").
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
// 找到生成中的记录,设置状态
statusStr := string(imageGen.Status)
drama.Episodes[i].Scenes[j].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
drama.Episodes[i].Scenes[j].ImageGenerationError = imageGen.ErrorMsg
}
} else if errors.Is(err, gorm.ErrRecordNotFound) {
// 检查是否有失败的记录
err := s.db.Where("scene_id = ? AND status = ?",
drama.Episodes[i].Scenes[j].ID, "failed").
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
statusStr := string(imageGen.Status)
drama.Episodes[i].Scenes[j].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
drama.Episodes[i].Scenes[j].ImageGenerationError = imageGen.ErrorMsg
}
}
}
}
}
// 整合所有剧集的场景到Drama级别的Scenes字段
sceneMap := make(map[uint]*models.Scene) // 用于去重
for i := range drama.Episodes {
for j := range drama.Episodes[i].Scenes {
scene := &drama.Episodes[i].Scenes[j]
sceneMap[scene.ID] = scene
}
}
// 将整合的场景添加到drama.Scenes
drama.Scenes = make([]models.Scene, 0, len(sceneMap))
for _, scene := range sceneMap {
drama.Scenes = append(drama.Scenes, *scene)
}
return &drama, nil
}
func (s *DramaService) ListDramas(query *DramaListQuery) ([]models.Drama, int64, error) {
var dramas []models.Drama
var total int64
db := s.db.Model(&models.Drama{})
if query.Status != "" {
db = db.Where("status = ?", query.Status)
}
if query.Genre != "" {
db = db.Where("genre = ?", query.Genre)
}
if query.Keyword != "" {
db = db.Where("title LIKE ? OR description LIKE ?", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
}
if err := db.Count(&total).Error; err != nil {
s.log.Errorw("Failed to count dramas", "error", err)
return nil, 0, err
}
offset := (query.Page - 1) * query.PageSize
err := db.Order("updated_at DESC").
Offset(offset).
Limit(query.PageSize).
Preload("Episodes.Storyboards", func(db *gorm.DB) *gorm.DB {
return db.Order("storyboards.storyboard_number ASC")
}).
Find(&dramas).Error
if err != nil {
s.log.Errorw("Failed to list dramas", "error", err)
return nil, 0, err
}
// 统计每个剧本的每个剧集的时长(基于场景时长之和)
for i := range dramas {
for j := range dramas[i].Episodes {
totalDuration := 0
for _, scene := range dramas[i].Episodes[j].Storyboards {
totalDuration += scene.Duration
}
// 更新剧集时长(秒转分钟,向上取整)
durationMinutes := (totalDuration + 59) / 60
dramas[i].Episodes[j].Duration = durationMinutes
}
}
return dramas, total, nil
}
func (s *DramaService) UpdateDrama(dramaID string, req *UpdateDramaRequest) (*models.Drama, error) {
var drama models.Drama
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("drama not found")
}
return nil, err
}
updates := make(map[string]interface{})
if req.Title != "" {
updates["title"] = req.Title
}
if req.Description != "" {
updates["description"] = req.Description
}
if req.Genre != "" {
updates["genre"] = req.Genre
}
if req.Tags != "" {
updates["tags"] = req.Tags
}
if req.Status != "" {
updates["status"] = req.Status
}
updates["updated_at"] = time.Now()
if err := s.db.Model(&drama).Updates(updates).Error; err != nil {
s.log.Errorw("Failed to update drama", "error", err)
return nil, err
}
s.log.Infow("Drama updated", "drama_id", dramaID)
return &drama, nil
}
func (s *DramaService) DeleteDrama(dramaID string) error {
result := s.db.Where("id = ? ", dramaID).Delete(&models.Drama{})
if result.Error != nil {
s.log.Errorw("Failed to delete drama", "error", result.Error)
return result.Error
}
if result.RowsAffected == 0 {
return errors.New("drama not found")
}
s.log.Infow("Drama deleted", "drama_id", dramaID)
return nil
}
func (s *DramaService) GetDramaStats() (map[string]interface{}, error) {
var total int64
var byStatus []struct {
Status string
Count int64
}
if err := s.db.Model(&models.Drama{}).Count(&total).Error; err != nil {
return nil, err
}
if err := s.db.Model(&models.Drama{}).
Select("status, count(*) as count").
Group("status").
Scan(&byStatus).Error; err != nil {
return nil, err
}
stats := map[string]interface{}{
"total": total,
"by_status": byStatus,
}
return stats, nil
}
type SaveOutlineRequest struct {
Title string `json:"title" binding:"required"`
Summary string `json:"summary" binding:"required"`
Genre string `json:"genre"`
Tags []string `json:"tags"`
}
type SaveCharactersRequest struct {
Characters []models.Character `json:"characters" binding:"required"`
EpisodeID *uint `json:"episode_id"` // 可选:如果提供则关联到指定章节
}
type SaveProgressRequest struct {
CurrentStep string `json:"current_step" binding:"required"`
StepData map[string]interface{} `json:"step_data"`
}
type SaveEpisodesRequest struct {
Episodes []models.Episode `json:"episodes" binding:"required"`
}
func (s *DramaService) SaveOutline(dramaID string, req *SaveOutlineRequest) error {
var drama models.Drama
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("drama not found")
}
return err
}
updates := map[string]interface{}{
"title": req.Title,
"description": req.Summary,
"updated_at": time.Now(),
}
if req.Genre != "" {
updates["genre"] = req.Genre
}
if len(req.Tags) > 0 {
tagsJSON, err := json.Marshal(req.Tags)
if err != nil {
s.log.Errorw("Failed to marshal tags", "error", err)
return err
}
updates["tags"] = tagsJSON
}
if err := s.db.Model(&drama).Updates(updates).Error; err != nil {
s.log.Errorw("Failed to save outline", "error", err)
return err
}
s.log.Infow("Outline saved", "drama_id", dramaID)
return nil
}
func (s *DramaService) GetCharacters(dramaID string, episodeID *string) ([]models.Character, error) {
var drama models.Drama
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("drama not found")
}
return nil, err
}
var characters []models.Character
// 如果指定了episodeID只获取该章节关联的角色
if episodeID != nil {
var episode models.Episode
if err := s.db.Preload("Characters").Where("id = ? AND drama_id = ?", *episodeID, dramaID).First(&episode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("episode not found")
}
return nil, err
}
characters = episode.Characters
} else {
// 如果没有指定episodeID获取项目的所有角色
if err := s.db.Where("drama_id = ?", dramaID).Find(&characters).Error; err != nil {
s.log.Errorw("Failed to get characters", "error", err)
return nil, err
}
}
// 查询每个角色的图片生成任务状态
for i := range characters {
// 查询该角色最新的图片生成任务
var imageGen models.ImageGeneration
err := s.db.Where("character_id = ?", characters[i].ID).
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
// 如果有进行中的任务,填充状态信息
if imageGen.Status == models.ImageStatusPending || imageGen.Status == models.ImageStatusProcessing {
statusStr := string(imageGen.Status)
characters[i].ImageGenerationStatus = &statusStr
} else if imageGen.Status == models.ImageStatusFailed {
statusStr := "failed"
characters[i].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
characters[i].ImageGenerationError = imageGen.ErrorMsg
}
}
}
}
return characters, nil
}
func (s *DramaService) SaveCharacters(dramaID string, req *SaveCharactersRequest) error {
// 转换dramaID
id, err := strconv.ParseUint(dramaID, 10, 32)
if err != nil {
return fmt.Errorf("invalid drama ID")
}
dramaIDUint := uint(id)
var drama models.Drama
if err := s.db.Where("id = ? ", dramaIDUint).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("drama not found")
}
return err
}
// 如果指定了EpisodeID验证章节存在性
if req.EpisodeID != nil {
var episode models.Episode
if err := s.db.Where("id = ? AND drama_id = ?", *req.EpisodeID, dramaIDUint).First(&episode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("episode not found")
}
return err
}
}
// 获取该项目已存在的所有角色
var existingCharacters []models.Character
if err := s.db.Where("drama_id = ?", dramaIDUint).Find(&existingCharacters).Error; err != nil {
s.log.Errorw("Failed to get existing characters", "error", err)
return err
}
// 创建角色名称到角色的映射
existingCharMap := make(map[string]*models.Character)
for i := range existingCharacters {
existingCharMap[existingCharacters[i].Name] = &existingCharacters[i]
}
// 收集需要关联到章节的角色ID
var characterIDs []uint
// 创建新角色或复用已有角色
for _, char := range req.Characters {
if existingChar, exists := existingCharMap[char.Name]; exists {
// 角色已存在,直接复用
s.log.Infow("Character already exists, reusing", "name", char.Name, "character_id", existingChar.ID)
characterIDs = append(characterIDs, existingChar.ID)
continue
}
// 角色不存在,创建新角色
character := models.Character{
DramaID: dramaIDUint,
Name: char.Name,
Role: char.Role,
Description: char.Description,
Personality: char.Personality,
Appearance: char.Appearance,
}
if err := s.db.Create(&character).Error; err != nil {
s.log.Errorw("Failed to create character", "error", err, "name", char.Name)
continue
}
s.log.Infow("New character created", "character_id", character.ID, "name", char.Name)
characterIDs = append(characterIDs, character.ID)
}
// 如果指定了EpisodeID建立角色与章节的关联
if req.EpisodeID != nil && len(characterIDs) > 0 {
var episode models.Episode
if err := s.db.First(&episode, *req.EpisodeID).Error; err != nil {
return err
}
// 获取角色对象
var characters []models.Character
if err := s.db.Where("id IN ?", characterIDs).Find(&characters).Error; err != nil {
s.log.Errorw("Failed to get characters", "error", err)
return err
}
// 使用GORM的Association API建立多对多关系会自动去重
if err := s.db.Model(&episode).Association("Characters").Append(&characters); err != nil {
s.log.Errorw("Failed to associate characters with episode", "error", err)
return err
}
s.log.Infow("Characters associated with episode", "episode_id", *req.EpisodeID, "character_count", len(characterIDs))
}
if err := s.db.Model(&drama).Update("updated_at", time.Now()).Error; err != nil {
s.log.Errorw("Failed to update drama timestamp", "error", err)
}
s.log.Infow("Characters saved", "drama_id", dramaID, "count", len(req.Characters))
return nil
}
func (s *DramaService) SaveEpisodes(dramaID string, req *SaveEpisodesRequest) error {
// 转换dramaID
id, err := strconv.ParseUint(dramaID, 10, 32)
if err != nil {
return fmt.Errorf("invalid drama ID")
}
dramaIDUint := uint(id)
var drama models.Drama
if err := s.db.Where("id = ? ", dramaIDUint).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("drama not found")
}
return err
}
// 删除旧剧集
if err := s.db.Where("drama_id = ?", dramaIDUint).Delete(&models.Episode{}).Error; err != nil {
s.log.Errorw("Failed to delete old episodes", "error", err)
return err
}
// 创建新剧集(不包含场景,场景由后续步骤生成)
for _, ep := range req.Episodes {
episode := models.Episode{
DramaID: dramaIDUint,
EpisodeNum: ep.EpisodeNum,
Title: ep.Title,
Description: ep.Description,
ScriptContent: ep.ScriptContent,
Duration: ep.Duration,
Status: "draft",
}
if err := s.db.Create(&episode).Error; err != nil {
s.log.Errorw("Failed to create episode", "error", err, "episode", ep.EpisodeNum)
continue
}
}
if err := s.db.Model(&drama).Update("updated_at", time.Now()).Error; err != nil {
s.log.Errorw("Failed to update drama timestamp", "error", err)
}
s.log.Infow("Episodes saved", "drama_id", dramaID, "count", len(req.Episodes))
return nil
}
func (s *DramaService) SaveProgress(dramaID string, req *SaveProgressRequest) error {
var drama models.Drama
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("drama not found")
}
return err
}
// 构建metadata对象
metadata := make(map[string]interface{})
// 保留现有metadata
if drama.Metadata != nil {
if err := json.Unmarshal(drama.Metadata, &metadata); err != nil {
s.log.Warnw("Failed to unmarshal existing metadata", "error", err)
}
}
// 更新progress信息
metadata["current_step"] = req.CurrentStep
if req.StepData != nil {
metadata["step_data"] = req.StepData
}
// 序列化metadata
metadataJSON, err := json.Marshal(metadata)
if err != nil {
s.log.Errorw("Failed to marshal metadata", "error", err)
return err
}
updates := map[string]interface{}{
"metadata": metadataJSON,
"updated_at": time.Now(),
}
if err := s.db.Model(&drama).Updates(updates).Error; err != nil {
s.log.Errorw("Failed to save progress", "error", err)
return err
}
s.log.Infow("Progress saved", "drama_id", dramaID, "step", req.CurrentStep)
return nil
}

View File

@@ -0,0 +1,428 @@
package services
import (
"fmt"
"strings"
"github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
// FramePromptService 处理帧提示词生成
type FramePromptService struct {
db *gorm.DB
aiService *AIService
log *logger.Logger
}
// NewFramePromptService 创建帧提示词服务
func NewFramePromptService(db *gorm.DB, log *logger.Logger) *FramePromptService {
return &FramePromptService{
db: db,
aiService: NewAIService(db, log),
log: log,
}
}
// FrameType 帧类型
type FrameType string
const (
FrameTypeFirst FrameType = "first" // 首帧
FrameTypeKey FrameType = "key" // 关键帧
FrameTypeLast FrameType = "last" // 尾帧
FrameTypePanel FrameType = "panel" // 分镜板3格组合
FrameTypeAction FrameType = "action" // 动作序列5格
)
// GenerateFramePromptRequest 生成帧提示词请求
type GenerateFramePromptRequest struct {
StoryboardID string `json:"storyboard_id"`
FrameType FrameType `json:"frame_type"`
// 可选参数
PanelCount int `json:"panel_count,omitempty"` // 分镜板格数默认3
}
// FramePromptResponse 帧提示词响应
type FramePromptResponse struct {
FrameType FrameType `json:"frame_type"`
SingleFrame *SingleFramePrompt `json:"single_frame,omitempty"` // 单帧提示词
MultiFrame *MultiFramePrompt `json:"multi_frame,omitempty"` // 多帧提示词
}
// SingleFramePrompt 单帧提示词
type SingleFramePrompt struct {
Prompt string `json:"prompt"`
Description string `json:"description"`
}
// MultiFramePrompt 多帧提示词
type MultiFramePrompt struct {
Layout string `json:"layout"` // horizontal_3, grid_2x2 等
Frames []SingleFramePrompt `json:"frames"`
}
// GenerateFramePrompt 生成指定类型的帧提示词并保存到frame_prompts表
func (s *FramePromptService) GenerateFramePrompt(req GenerateFramePromptRequest) (*FramePromptResponse, error) {
// 查询分镜信息
var storyboard models.Storyboard
if err := s.db.Preload("Characters").First(&storyboard, req.StoryboardID).Error; err != nil {
return nil, fmt.Errorf("storyboard not found: %w", err)
}
// 获取场景信息
var scene *models.Scene
if storyboard.SceneID != nil {
scene = &models.Scene{}
if err := s.db.First(scene, *storyboard.SceneID).Error; err != nil {
s.log.Warnw("Scene not found", "scene_id", *storyboard.SceneID)
scene = nil
}
}
response := &FramePromptResponse{
FrameType: req.FrameType,
}
// 生成提示词
switch req.FrameType {
case FrameTypeFirst:
response.SingleFrame = s.generateFirstFrame(storyboard, scene)
// 保存单帧提示词
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), response.SingleFrame.Prompt, response.SingleFrame.Description, "")
case FrameTypeKey:
response.SingleFrame = s.generateKeyFrame(storyboard, scene)
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), response.SingleFrame.Prompt, response.SingleFrame.Description, "")
case FrameTypeLast:
response.SingleFrame = s.generateLastFrame(storyboard, scene)
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), response.SingleFrame.Prompt, response.SingleFrame.Description, "")
case FrameTypePanel:
count := req.PanelCount
if count == 0 {
count = 3
}
response.MultiFrame = s.generatePanelFrames(storyboard, scene, count)
// 保存多帧提示词(合并为一条记录)
var prompts []string
for _, frame := range response.MultiFrame.Frames {
prompts = append(prompts, frame.Prompt)
}
combinedPrompt := strings.Join(prompts, "\n---\n")
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), combinedPrompt, "分镜板组合提示词", response.MultiFrame.Layout)
case FrameTypeAction:
response.MultiFrame = s.generateActionSequence(storyboard, scene)
var prompts []string
for _, frame := range response.MultiFrame.Frames {
prompts = append(prompts, frame.Prompt)
}
combinedPrompt := strings.Join(prompts, "\n---\n")
s.saveFramePrompt(req.StoryboardID, string(req.FrameType), combinedPrompt, "动作序列组合提示词", response.MultiFrame.Layout)
default:
return nil, fmt.Errorf("unsupported frame type: %s", req.FrameType)
}
return response, nil
}
// saveFramePrompt 保存帧提示词到数据库
func (s *FramePromptService) saveFramePrompt(storyboardID, frameType, prompt, description, layout string) {
framePrompt := models.FramePrompt{
StoryboardID: uint(mustParseUint(storyboardID)),
FrameType: frameType,
Prompt: prompt,
}
if description != "" {
framePrompt.Description = &description
}
if layout != "" {
framePrompt.Layout = &layout
}
// 先删除同类型的旧记录(保持最新)
s.db.Where("storyboard_id = ? AND frame_type = ?", storyboardID, frameType).Delete(&models.FramePrompt{})
// 插入新记录
if err := s.db.Create(&framePrompt).Error; err != nil {
s.log.Warnw("Failed to save frame prompt", "error", err, "storyboard_id", storyboardID, "frame_type", frameType)
}
}
// mustParseUint 辅助函数
func mustParseUint(s string) uint64 {
var result uint64
fmt.Sscanf(s, "%d", &result)
return result
}
// generateFirstFrame 生成首帧提示词
func (s *FramePromptService) generateFirstFrame(sb models.Storyboard, scene *models.Scene) *SingleFramePrompt {
// 构建上下文信息
contextInfo := s.buildStoryboardContext(sb, scene)
// 构建AI提示词
systemPrompt := `你是一个专业的图像生成提示词专家。请根据提供的镜头信息生成适合用于AI图像生成的提示词。
重要:这是镜头的首帧 - 一个完全静态的画面,展示动作发生之前的初始状态。
要求:
1. 直接输出提示词,不要任何解释说明
2. 可以使用中文或英文,用逗号分隔关键词
3. 只描述静态视觉元素:场景环境、角色姿态、表情、氛围、光线
4. 不要包含任何动作动词(如:猛然、弹起、坐直、抓住等)
5. 描述角色处于动作发生前的状态(如:躺在床上、站立、坐着等静态姿态)
6. 适合动画风格anime style
示例格式:
Anime style, 城市公寓卧室, 凌晨, 昏暗房间, 床上, 年轻男子躺着, 表情平静, 闭眼睡眠, 柔和光线, 静谧氛围, 中景, 平视`
userPrompt := fmt.Sprintf(`镜头信息:
%s
请直接生成首帧的图像提示词,不要任何解释:`, contextInfo)
// 调用AI生成
prompt, err := s.aiService.GenerateText(userPrompt, systemPrompt)
if err != nil {
s.log.Warnw("AI generation failed, using fallback", "error", err)
// 降级方案:使用简单拼接
prompt = s.buildFallbackPrompt(sb, scene, "first frame, static shot")
}
// 如果AI返回空字符串使用降级方案
prompt = strings.TrimSpace(prompt)
if prompt == "" {
s.log.Warnw("AI returned empty prompt, using fallback", "storyboard_id", sb.ID)
prompt = s.buildFallbackPrompt(sb, scene, "first frame, static shot")
}
return &SingleFramePrompt{
Prompt: prompt,
Description: "镜头开始的静态画面,展示初始状态",
}
}
// generateKeyFrame 生成关键帧提示词
func (s *FramePromptService) generateKeyFrame(sb models.Storyboard, scene *models.Scene) *SingleFramePrompt {
// 构建上下文信息
contextInfo := s.buildStoryboardContext(sb, scene)
// 构建AI提示词
systemPrompt := `你是一个专业的图像生成提示词专家。请根据提供的镜头信息生成适合用于AI图像生成的提示词。
重要:这是镜头的关键帧 - 捕捉动作最激烈、最精彩的瞬间。
要求:
1. 直接输出提示词,不要任何解释说明
2. 可以使用中文或英文,用逗号分隔关键词
3. 重点描述动作的高潮瞬间:身体姿态、运动轨迹、力量感
4. 包含动态元素:动作模糊、速度线、冲击感
5. 强调表情和情绪的极致状态
6. 适合动画风格anime style
示例格式:
Anime style, 城市街道, 白天, 男子全力冲刺, 身体前倾, 动作模糊, 速度线, 汗水飞溅, 表情坚毅, 紧张氛围, 动态镜头, 中景`
userPrompt := fmt.Sprintf(`镜头信息:
%s
请直接生成关键帧的图像提示词,不要任何解释:`, contextInfo)
// 调用AI生成
prompt, err := s.aiService.GenerateText(userPrompt, systemPrompt)
if err != nil {
s.log.Warnw("AI generation failed, using fallback", "error", err)
prompt = s.buildFallbackPrompt(sb, scene, "key frame, dynamic action")
}
// 如果AI返回空字符串使用降级方案
prompt = strings.TrimSpace(prompt)
if prompt == "" {
s.log.Warnw("AI returned empty prompt, using fallback", "storyboard_id", sb.ID)
prompt = s.buildFallbackPrompt(sb, scene, "key frame, dynamic action")
}
return &SingleFramePrompt{
Prompt: prompt,
Description: "动作高潮瞬间,展示关键动作",
}
}
// generateLastFrame 生成尾帧提示词
func (s *FramePromptService) generateLastFrame(sb models.Storyboard, scene *models.Scene) *SingleFramePrompt {
// 构建上下文信息
contextInfo := s.buildStoryboardContext(sb, scene)
// 构建AI提示词
systemPrompt := `你是一个专业的图像生成提示词专家。请根据提供的镜头信息生成适合用于AI图像生成的提示词。
重要:这是镜头的尾帧 - 一个静态画面,展示动作结束后的最终状态和结果。
要求:
1. 直接输出提示词,不要任何解释说明
2. 可以使用中文或英文,用逗号分隔关键词
3. 只描述静态的最终状态:角色姿态、表情、环境变化
4. 不要包含动作过程,只展示动作的结果和余韵
5. 强调情绪的余波和氛围的沉淀
6. 适合动画风格anime style
示例格式:
Anime style, 房间内, 黄昏, 男子坐在椅子上, 身体放松, 表情疲惫, 长出一口气, 汗水滴落, 平静氛围, 静态镜头, 中景`
userPrompt := fmt.Sprintf(`镜头信息:
%s
请直接生成尾帧的图像提示词,不要任何解释:`, contextInfo)
// 调用AI生成
prompt, err := s.aiService.GenerateText(userPrompt, systemPrompt)
if err != nil {
s.log.Warnw("AI generation failed, using fallback", "error", err)
prompt = s.buildFallbackPrompt(sb, scene, "last frame, final state")
}
// 如果AI返回空字符串使用降级方案
prompt = strings.TrimSpace(prompt)
if prompt == "" {
s.log.Warnw("AI returned empty prompt, using fallback", "storyboard_id", sb.ID)
prompt = s.buildFallbackPrompt(sb, scene, "last frame, final state")
}
return &SingleFramePrompt{
Prompt: prompt,
Description: "镜头结束画面,展示最终状态和结果",
}
}
// generatePanelFrames 生成分镜板(多格组合)
func (s *FramePromptService) generatePanelFrames(sb models.Storyboard, scene *models.Scene, count int) *MultiFramePrompt {
layout := fmt.Sprintf("horizontal_%d", count)
frames := make([]SingleFramePrompt, count)
// 固定生成:首帧 -> 关键帧 -> 尾帧
if count == 3 {
frames[0] = *s.generateFirstFrame(sb, scene)
frames[0].Description = "第1格初始状态"
frames[1] = *s.generateKeyFrame(sb, scene)
frames[1].Description = "第2格动作高潮"
frames[2] = *s.generateLastFrame(sb, scene)
frames[2].Description = "第3格最终状态"
} else if count == 4 {
// 4格首帧 -> 中间帧1 -> 中间帧2 -> 尾帧
frames[0] = *s.generateFirstFrame(sb, scene)
frames[1] = *s.generateKeyFrame(sb, scene)
frames[2] = *s.generateKeyFrame(sb, scene)
frames[3] = *s.generateLastFrame(sb, scene)
}
return &MultiFramePrompt{
Layout: layout,
Frames: frames,
}
}
// generateActionSequence 生成动作序列5-8格
func (s *FramePromptService) generateActionSequence(sb models.Storyboard, scene *models.Scene) *MultiFramePrompt {
// 将动作分解为5个步骤
frames := make([]SingleFramePrompt, 5)
// 简化实现:均匀分布从首帧到尾帧
frames[0] = *s.generateFirstFrame(sb, scene)
frames[1] = *s.generateKeyFrame(sb, scene)
frames[2] = *s.generateKeyFrame(sb, scene)
frames[3] = *s.generateKeyFrame(sb, scene)
frames[4] = *s.generateLastFrame(sb, scene)
return &MultiFramePrompt{
Layout: "horizontal_5",
Frames: frames,
}
}
// buildStoryboardContext 构建镜头上下文信息
func (s *FramePromptService) buildStoryboardContext(sb models.Storyboard, scene *models.Scene) string {
var parts []string
// 镜头描述(最重要)
if sb.Description != nil && *sb.Description != "" {
parts = append(parts, fmt.Sprintf("镜头描述: %s", *sb.Description))
}
// 场景信息
if scene != nil {
parts = append(parts, fmt.Sprintf("场景: %s, %s", scene.Location, scene.Time))
} else if sb.Location != nil && sb.Time != nil {
parts = append(parts, fmt.Sprintf("场景: %s, %s", *sb.Location, *sb.Time))
}
// 角色
if len(sb.Characters) > 0 {
var charNames []string
for _, char := range sb.Characters {
charNames = append(charNames, char.Name)
}
parts = append(parts, fmt.Sprintf("角色: %s", strings.Join(charNames, ", ")))
}
// 动作
if sb.Action != nil && *sb.Action != "" {
parts = append(parts, fmt.Sprintf("动作: %s", *sb.Action))
}
// 结果
if sb.Result != nil && *sb.Result != "" {
parts = append(parts, fmt.Sprintf("结果: %s", *sb.Result))
}
// 对白
if sb.Dialogue != nil && *sb.Dialogue != "" {
parts = append(parts, fmt.Sprintf("对白: %s", *sb.Dialogue))
}
// 氛围
if sb.Atmosphere != nil && *sb.Atmosphere != "" {
parts = append(parts, fmt.Sprintf("氛围: %s", *sb.Atmosphere))
}
// 镜头参数
if sb.ShotType != nil {
parts = append(parts, fmt.Sprintf("景别: %s", *sb.ShotType))
}
if sb.Angle != nil {
parts = append(parts, fmt.Sprintf("角度: %s", *sb.Angle))
}
if sb.Movement != nil {
parts = append(parts, fmt.Sprintf("运镜: %s", *sb.Movement))
}
return strings.Join(parts, "\n")
}
// buildFallbackPrompt 构建降级提示词AI失败时使用
func (s *FramePromptService) buildFallbackPrompt(sb models.Storyboard, scene *models.Scene, suffix string) string {
var parts []string
// 场景
if scene != nil {
parts = append(parts, fmt.Sprintf("%s, %s", scene.Location, scene.Time))
}
// 角色
if len(sb.Characters) > 0 {
for _, char := range sb.Characters {
parts = append(parts, char.Name)
}
}
// 氛围
if sb.Atmosphere != nil {
parts = append(parts, *sb.Atmosphere)
}
parts = append(parts, "anime style", suffix)
return strings.Join(parts, ", ")
}

View File

@@ -0,0 +1,984 @@
package services
import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/infrastructure/storage"
"github.com/drama-generator/backend/pkg/ai"
"github.com/drama-generator/backend/pkg/image"
"github.com/drama-generator/backend/pkg/logger"
"github.com/drama-generator/backend/pkg/utils"
"gorm.io/gorm"
)
type ImageGenerationService struct {
db *gorm.DB
aiService *AIService
transferService *ResourceTransferService
localStorage *storage.LocalStorage
log *logger.Logger
}
// truncateImageURL 截断图片 URL避免 base64 格式的 URL 占满日志
func truncateImageURL(url string) string {
if url == "" {
return ""
}
// 如果是 data URI 格式base64只显示前缀
if strings.HasPrefix(url, "data:") {
if len(url) > 50 {
return url[:50] + "...[base64 data]"
}
}
// 普通 URL 如果过长也截断
if len(url) > 100 {
return url[:100] + "..."
}
return url
}
func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService {
return &ImageGenerationService{
db: db,
aiService: NewAIService(db, log),
transferService: transferService,
localStorage: localStorage,
log: log,
}
}
// GetDB 获取数据库连接
func (s *ImageGenerationService) GetDB() *gorm.DB {
return s.db
}
type GenerateImageRequest struct {
StoryboardID *uint `json:"storyboard_id"`
DramaID string `json:"drama_id" binding:"required"`
SceneID *uint `json:"scene_id"`
CharacterID *uint `json:"character_id"`
ImageType string `json:"image_type"` // character, scene, storyboard
FrameType *string `json:"frame_type"` // first, key, last, panel, action
Prompt string `json:"prompt" binding:"required,min=5,max=2000"`
NegativePrompt *string `json:"negative_prompt"`
Provider string `json:"provider"`
Model string `json:"model"`
Size string `json:"size"`
Quality string `json:"quality"`
Style *string `json:"style"`
Steps *int `json:"steps"`
CfgScale *float64 `json:"cfg_scale"`
Seed *int64 `json:"seed"`
Width *int `json:"width"`
Height *int `json:"height"`
ReferenceImages []string `json:"reference_images"` // 参考图片URL列表
}
func (s *ImageGenerationService) GenerateImage(request *GenerateImageRequest) (*models.ImageGeneration, error) {
var drama models.Drama
if err := s.db.Where("id = ? ", request.DramaID).First(&drama).Error; err != nil {
return nil, fmt.Errorf("drama not found")
}
// 注意SceneID可能指向Scene或Storyboard表调用方已经做过权限验证这里不再重复验证
provider := request.Provider
if provider == "" {
provider = "openai"
}
// 序列化参考图片
var referenceImagesJSON []byte
if len(request.ReferenceImages) > 0 {
referenceImagesJSON, _ = json.Marshal(request.ReferenceImages)
}
// 转换DramaID
dramaIDParsed, err := strconv.ParseUint(request.DramaID, 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid drama ID")
}
// 设置默认图片类型
imageType := request.ImageType
if imageType == "" {
imageType = string(models.ImageTypeStoryboard)
}
imageGen := &models.ImageGeneration{
StoryboardID: request.StoryboardID,
DramaID: uint(dramaIDParsed),
SceneID: request.SceneID,
CharacterID: request.CharacterID,
ImageType: imageType,
FrameType: request.FrameType,
Provider: provider,
Prompt: request.Prompt,
NegPrompt: request.NegativePrompt,
Model: request.Model,
Size: request.Size,
ReferenceImages: referenceImagesJSON,
Quality: request.Quality,
Style: request.Style,
Steps: request.Steps,
CfgScale: request.CfgScale,
Seed: request.Seed,
Width: request.Width,
Height: request.Height,
Status: models.ImageStatusPending,
}
if err := s.db.Create(imageGen).Error; err != nil {
return nil, fmt.Errorf("failed to create record: %w", err)
}
go s.ProcessImageGeneration(imageGen.ID)
return imageGen, nil
}
func (s *ImageGenerationService) ProcessImageGeneration(imageGenID uint) {
var imageGen models.ImageGeneration
if err := s.db.First(&imageGen, imageGenID).Error; err != nil {
s.log.Errorw("Failed to load image generation", "error", err, "id", imageGenID)
return
}
s.db.Model(&imageGen).Update("status", models.ImageStatusProcessing)
// 如果关联了background同步更新background为generating状态
if imageGen.StoryboardID != nil {
if err := s.db.Model(&models.Scene{}).Where("id = ?", *imageGen.StoryboardID).Update("status", "generating").Error; err != nil {
s.log.Warnw("Failed to update background status to generating", "scene_id", *imageGen.StoryboardID, "error", err)
} else {
s.log.Infow("Background status updated to generating", "scene_id", *imageGen.StoryboardID)
}
}
client, err := s.getImageClientWithModel(imageGen.Provider, imageGen.Model)
if err != nil {
s.log.Errorw("Failed to get image client", "error", err, "provider", imageGen.Provider, "model", imageGen.Model)
s.updateImageGenError(imageGenID, err.Error())
return
}
// 解析参考图片
var referenceImages []string
if len(imageGen.ReferenceImages) > 0 {
if err := json.Unmarshal(imageGen.ReferenceImages, &referenceImages); err == nil {
s.log.Infow("Using reference images for generation",
"id", imageGenID,
"reference_count", len(referenceImages),
"references", referenceImages)
}
}
s.log.Infow("Starting image generation", "id", imageGenID, "prompt", imageGen.Prompt, "provider", imageGen.Provider)
var opts []image.ImageOption
if imageGen.NegPrompt != nil && *imageGen.NegPrompt != "" {
opts = append(opts, image.WithNegativePrompt(*imageGen.NegPrompt))
}
if imageGen.Size != "" {
opts = append(opts, image.WithSize(imageGen.Size))
}
if imageGen.Quality != "" {
opts = append(opts, image.WithQuality(imageGen.Quality))
}
if imageGen.Style != nil && *imageGen.Style != "" {
opts = append(opts, image.WithStyle(*imageGen.Style))
}
if imageGen.Steps != nil {
opts = append(opts, image.WithSteps(*imageGen.Steps))
}
if imageGen.CfgScale != nil {
opts = append(opts, image.WithCfgScale(*imageGen.CfgScale))
}
if imageGen.Seed != nil {
opts = append(opts, image.WithSeed(*imageGen.Seed))
}
if imageGen.Model != "" {
opts = append(opts, image.WithModel(imageGen.Model))
}
if imageGen.Width != nil && imageGen.Height != nil {
opts = append(opts, image.WithDimensions(*imageGen.Width, *imageGen.Height))
}
// 添加参考图片
if len(referenceImages) > 0 {
opts = append(opts, image.WithReferenceImages(referenceImages))
}
result, err := client.GenerateImage(imageGen.Prompt, opts...)
if err != nil {
s.log.Errorw("Image generation API call failed", "error", err, "id", imageGenID, "prompt", imageGen.Prompt)
s.updateImageGenError(imageGenID, err.Error())
return
}
s.log.Infow("Image generation API call completed", "id", imageGenID, "completed", result.Completed, "has_url", result.ImageURL != "")
if !result.Completed {
s.db.Model(&imageGen).Updates(map[string]interface{}{
"status": models.ImageStatusProcessing,
"task_id": result.TaskID,
})
go s.pollTaskStatus(imageGenID, client, result.TaskID)
return
}
s.completeImageGeneration(imageGenID, result)
}
func (s *ImageGenerationService) pollTaskStatus(imageGenID uint, client image.ImageClient, taskID string) {
maxAttempts := 60
pollInterval := 5 * time.Second
for i := 0; i < maxAttempts; i++ {
time.Sleep(pollInterval)
result, err := client.GetTaskStatus(taskID)
if err != nil {
s.log.Errorw("Failed to get task status", "error", err, "task_id", taskID)
continue
}
if result.Completed {
s.completeImageGeneration(imageGenID, result)
return
}
if result.Error != "" {
s.updateImageGenError(imageGenID, result.Error)
return
}
}
s.updateImageGenError(imageGenID, "timeout: image generation took too long")
}
func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result *image.ImageResult) {
now := time.Now()
// 下载图片到本地存储(仅用于缓存,不更新数据库)
// 仅下载 HTTP/HTTPS URL跳过 data URI
if s.localStorage != nil && result.ImageURL != "" &&
(strings.HasPrefix(result.ImageURL, "http://") || strings.HasPrefix(result.ImageURL, "https://")) {
_, err := s.localStorage.DownloadFromURL(result.ImageURL, "images")
if err != nil {
errStr := err.Error()
if len(errStr) > 200 {
errStr = errStr[:200] + "..."
}
s.log.Warnw("Failed to download image to local storage",
"error", errStr,
"id", imageGenID,
"original_url", truncateImageURL(result.ImageURL))
} else {
s.log.Infow("Image downloaded to local storage for caching",
"id", imageGenID,
"original_url", truncateImageURL(result.ImageURL))
}
}
// 数据库中保持使用原始URL
updates := map[string]interface{}{
"status": models.ImageStatusCompleted,
"image_url": result.ImageURL,
"completed_at": now,
}
if result.Width > 0 {
updates["width"] = result.Width
}
if result.Height > 0 {
updates["height"] = result.Height
}
// 更新image_generation记录
var imageGen models.ImageGeneration
if err := s.db.Where("id = ?", imageGenID).First(&imageGen).Error; err != nil {
s.log.Errorw("Failed to load image generation", "error", err, "id", imageGenID)
return
}
s.db.Model(&models.ImageGeneration{}).Where("id = ?", imageGenID).Updates(updates)
s.log.Infow("Image generation completed", "id", imageGenID)
// 如果关联了storyboard同步更新storyboard的composed_image
if imageGen.StoryboardID != nil {
if err := s.db.Model(&models.Storyboard{}).Where("id = ?", *imageGen.StoryboardID).Update("composed_image", result.ImageURL).Error; err != nil {
s.log.Errorw("Failed to update storyboard composed_image", "error", err, "storyboard_id", *imageGen.StoryboardID)
} else {
s.log.Infow("Storyboard updated with composed image",
"storyboard_id", *imageGen.StoryboardID,
"composed_image", truncateImageURL(result.ImageURL))
}
}
// 如果关联了scene同步更新scene的image_url和status仅当ImageType是scene时
if imageGen.SceneID != nil && imageGen.ImageType == string(models.ImageTypeScene) {
sceneUpdates := map[string]interface{}{
"status": "generated",
"image_url": result.ImageURL,
}
if err := s.db.Model(&models.Scene{}).Where("id = ?", *imageGen.SceneID).Updates(sceneUpdates).Error; err != nil {
s.log.Errorw("Failed to update scene", "error", err, "scene_id", *imageGen.SceneID)
} else {
s.log.Infow("Scene updated with generated image",
"scene_id", *imageGen.SceneID,
"image_url", truncateImageURL(result.ImageURL))
}
}
// 如果关联了角色同步更新角色的image_url
if imageGen.CharacterID != nil {
if err := s.db.Model(&models.Character{}).Where("id = ?", *imageGen.CharacterID).Update("image_url", result.ImageURL).Error; err != nil {
s.log.Errorw("Failed to update character image_url", "error", err, "character_id", *imageGen.CharacterID)
} else {
s.log.Infow("Character updated with generated image",
"character_id", *imageGen.CharacterID,
"image_url", truncateImageURL(result.ImageURL))
}
}
}
func (s *ImageGenerationService) updateImageGenError(imageGenID uint, errorMsg string) {
// 先获取image_generation记录
var imageGen models.ImageGeneration
if err := s.db.Where("id = ?", imageGenID).First(&imageGen).Error; err != nil {
s.log.Errorw("Failed to load image generation", "error", err, "id", imageGenID)
return
}
// 更新image_generation状态
s.db.Model(&models.ImageGeneration{}).Where("id = ?", imageGenID).Updates(map[string]interface{}{
"status": models.ImageStatusFailed,
"error_msg": errorMsg,
})
s.log.Errorw("Image generation failed", "id", imageGenID, "error", errorMsg)
// 如果关联了scene同步更新scene为失败状态
if imageGen.SceneID != nil {
s.db.Model(&models.Scene{}).Where("id = ?", *imageGen.SceneID).Update("status", "failed")
s.log.Warnw("Scene marked as failed", "scene_id", *imageGen.SceneID)
}
}
func (s *ImageGenerationService) getImageClient(provider string) (image.ImageClient, error) {
config, err := s.aiService.GetDefaultConfig("image")
if err != nil {
return nil, fmt.Errorf("no image AI config found: %w", err)
}
// 使用第一个模型
model := ""
if len(config.Model) > 0 {
model = config.Model[0]
}
// 使用配置中的 provider如果没有则使用传入的 provider
actualProvider := config.Provider
if actualProvider == "" {
actualProvider = provider
}
// 根据 provider 自动设置默认端点
var endpoint string
var queryEndpoint string
switch actualProvider {
case "openai", "dalle":
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "chatfire":
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "volcengine", "volces", "doubao":
endpoint = "/images/generations"
queryEndpoint = ""
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
case "gemini", "google":
endpoint = "/v1beta/models/{model}:generateContent"
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
default:
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
}
}
// getImageClientWithModel 根据模型名称获取图片客户端
func (s *ImageGenerationService) getImageClientWithModel(provider string, modelName string) (image.ImageClient, error) {
var config *models.AIServiceConfig
var err error
// 如果指定了模型,尝试获取对应的配置
if modelName != "" {
config, err = s.aiService.GetConfigForModel("image", modelName)
if err != nil {
s.log.Warnw("Failed to get config for model, using default", "model", modelName, "error", err)
config, err = s.aiService.GetDefaultConfig("image")
if err != nil {
return nil, fmt.Errorf("no image AI config found: %w", err)
}
}
} else {
config, err = s.aiService.GetDefaultConfig("image")
if err != nil {
return nil, fmt.Errorf("no image AI config found: %w", err)
}
}
// 使用指定的模型或配置中的第一个模型
model := modelName
if model == "" && len(config.Model) > 0 {
model = config.Model[0]
}
// 使用配置中的 provider如果没有则使用传入的 provider
actualProvider := config.Provider
if actualProvider == "" {
actualProvider = provider
}
// 根据 provider 自动设置默认端点
var endpoint string
var queryEndpoint string
switch actualProvider {
case "openai", "dalle":
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "chatfire":
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "volcengine", "volces", "doubao":
endpoint = "/images/generations"
queryEndpoint = ""
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
case "gemini", "google":
endpoint = "/v1beta/models/{model}:generateContent"
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
default:
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
}
}
func (s *ImageGenerationService) GetImageGeneration(imageGenID uint) (*models.ImageGeneration, error) {
var imageGen models.ImageGeneration
if err := s.db.Where("id = ? ", imageGenID).First(&imageGen).Error; err != nil {
return nil, err
}
return &imageGen, nil
}
func (s *ImageGenerationService) ListImageGenerations(dramaID *uint, sceneID *uint, storyboardID *uint, frameType string, status string, page, pageSize int) ([]models.ImageGeneration, int64, error) {
query := s.db.Model(&models.ImageGeneration{})
if dramaID != nil {
query = query.Where("drama_id = ?", *dramaID)
}
if sceneID != nil {
query = query.Where("scene_id = ?", *sceneID)
}
if storyboardID != nil {
query = query.Where("storyboard_id = ?", *storyboardID)
}
if frameType != "" {
query = query.Where("frame_type = ?", frameType)
}
if status != "" {
query = query.Where("status = ?", status)
}
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
var images []models.ImageGeneration
offset := (page - 1) * pageSize
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&images).Error; err != nil {
return nil, 0, err
}
return images, total, nil
}
func (s *ImageGenerationService) DeleteImageGeneration(imageGenID uint) error {
result := s.db.Where("id = ? ", imageGenID).Delete(&models.ImageGeneration{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("image generation not found")
}
return nil
}
func (s *ImageGenerationService) GenerateImagesForScene(sceneID string) ([]*models.ImageGeneration, error) {
// 转换sceneID
sid, err := strconv.ParseUint(sceneID, 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid scene ID")
}
sceneIDUint := uint(sid)
var scene models.Scene
if err := s.db.Where("id = ?", sceneIDUint).First(&scene).Error; err != nil {
return nil, fmt.Errorf("scene not found")
}
// 构建场景图片生成提示词
prompt := scene.Prompt
if prompt == "" {
// 如果Prompt为空使用Location和Time构建
prompt = fmt.Sprintf("%s场景%s", scene.Location, scene.Time)
}
req := &GenerateImageRequest{
SceneID: &sceneIDUint,
DramaID: fmt.Sprintf("%d", scene.DramaID),
ImageType: string(models.ImageTypeScene),
Prompt: prompt,
}
imageGen, err := s.GenerateImage(req)
if err != nil {
return nil, err
}
return []*models.ImageGeneration{imageGen}, nil
}
// BackgroundInfo 背景信息结构
type BackgroundInfo struct {
Location string `json:"location"`
Time string `json:"time"`
Atmosphere string `json:"atmosphere"`
Prompt string `json:"prompt"`
StoryboardNumbers []int `json:"storyboard_numbers"`
SceneIDs []uint `json:"scene_ids"`
StoryboardCount int `json:"scene_count"`
}
func (s *ImageGenerationService) BatchGenerateImagesForEpisode(episodeID string) ([]*models.ImageGeneration, error) {
var ep models.Episode
if err := s.db.Preload("Drama").Where("id = ?", episodeID).First(&ep).Error; err != nil {
return nil, fmt.Errorf("episode not found")
}
// 从数据库读取已保存的场景
var scenes []models.Storyboard
if err := s.db.Where("episode_id = ?", episodeID).Find(&scenes).Error; err != nil {
return nil, fmt.Errorf("failed to get scenes: %w", err)
}
backgrounds := s.extractUniqueBackgrounds(scenes)
s.log.Infow("Extracted unique backgrounds",
"episode_id", episodeID,
"background_count", len(backgrounds))
// 为每个背景生成图片
var results []*models.ImageGeneration
for _, bg := range scenes {
if bg.ImagePrompt == nil || *bg.ImagePrompt == "" {
s.log.Warnw("Background has no prompt, skipping", "scene_id", bg.ID)
continue
}
// 更新背景状态为处理中
s.db.Model(bg).Update("status", "generating")
req := &GenerateImageRequest{
StoryboardID: &bg.ID,
DramaID: fmt.Sprintf("%d", ep.DramaID),
Prompt: *bg.ImagePrompt,
}
imageGen, err := s.GenerateImage(req)
if err != nil {
s.log.Errorw("Failed to generate image for background",
"scene_id", bg.ID,
"location", bg.Location,
"error", err)
s.db.Model(bg).Update("status", "failed")
continue
}
s.log.Infow("Background image generation started",
"scene_id", bg.ID,
"image_gen_id", imageGen.ID,
"location", bg.Location,
"time", bg.Time)
results = append(results, imageGen)
}
return results, nil
}
// GetScencesForEpisode 获取项目的场景列表(项目级)
func (s *ImageGenerationService) GetScencesForEpisode(episodeID string) ([]*models.Scene, error) {
var episode models.Episode
if err := s.db.Preload("Drama").Where("id = ?", episodeID).First(&episode).Error; err != nil {
return nil, fmt.Errorf("episode not found")
}
// 场景是项目级的通过drama_id查询
var scenes []*models.Scene
if err := s.db.Where("drama_id = ?", episode.DramaID).Order("location ASC, time ASC").Find(&scenes).Error; err != nil {
return nil, fmt.Errorf("failed to load scenes: %w", err)
}
return scenes, nil
}
// ExtractBackgroundsForEpisode 从剧本内容中提取场景并保存到项目级别数据库
func (s *ImageGenerationService) ExtractBackgroundsForEpisode(episodeID string) ([]*models.Scene, error) {
var episode models.Episode
if err := s.db.Preload("Drama").Where("id = ?", episodeID).First(&episode).Error; err != nil {
return nil, fmt.Errorf("episode not found")
}
// 检查是否有剧本内容
if episode.ScriptContent == nil || *episode.ScriptContent == "" {
return nil, fmt.Errorf("剧本内容为空,无法提取场景")
}
dramaID := episode.DramaID
// 使用AI从剧本内容中提取场景
backgroundsInfo, err := s.extractBackgroundsFromScript(*episode.ScriptContent, dramaID)
if err != nil {
s.log.Errorw("Failed to extract backgrounds from script", "error", err)
return nil, err
}
// 保存到数据库不涉及Storyboard关联因为此时还没有生成分镜
var scenes []*models.Scene
err = s.db.Transaction(func(tx *gorm.DB) error {
// 先删除该章节的所有场景(实现重新提取覆盖功能)
if err := tx.Where("episode_id = ?", episode.ID).Delete(&models.Scene{}).Error; err != nil {
s.log.Errorw("Failed to delete old scenes", "error", err)
return err
}
s.log.Infow("Deleted old scenes for re-extraction", "episode_id", episode.ID)
// 创建新提取的场景
for _, bgInfo := range backgroundsInfo {
// 保存新场景到数据库(章节级)
episodeIDVal := episode.ID
scene := &models.Scene{
DramaID: dramaID,
EpisodeID: &episodeIDVal,
Location: bgInfo.Location,
Time: bgInfo.Time,
Prompt: bgInfo.Prompt,
StoryboardCount: 1, // 默认为1
Status: "pending",
}
if err := tx.Create(scene).Error; err != nil {
return err
}
scenes = append(scenes, scene)
s.log.Infow("Created new scene from script",
"scene_id", scene.ID,
"location", scene.Location,
"time", scene.Time)
}
return nil
})
if err != nil {
return nil, err
}
s.log.Infow("Saved scenes to database",
"episode_id", episodeID,
"total_storyboards", len(episode.Storyboards),
"unique_scenes", len(scenes))
return scenes, nil
}
// extractBackgroundsFromScript 从剧本内容中使用AI提取场景信息
func (s *ImageGenerationService) extractBackgroundsFromScript(scriptContent string, dramaID uint) ([]BackgroundInfo, error) {
if scriptContent == "" {
return []BackgroundInfo{}, nil
}
// 获取AI客户端
client, err := s.aiService.GetAIClient("text")
if err != nil {
return nil, fmt.Errorf("failed to get AI client: %w", err)
}
// 构建AI提示词
prompt := fmt.Sprintf(`【任务】分析以下剧本内容,提取出所有需要的场景背景信息。
【剧本内容】
%s
【要求】
1. 识别剧本中所有不同的场景(地点+时间组合)
2. 为每个场景生成详细的**中文**图片生成提示词Prompt
3. **重要**:场景描述必须是**纯背景**,不能包含人物、角色、动作等元素
4. Prompt要求
- **必须使用中文**,不能包含英文字符
- 详细描述场景环境、建筑、物品、光线、氛围等
- **禁止描述人物、角色、动作、对话等**
- 适合AI图片生成模型使用
- 风格统一为:电影感、细节丰富、动漫风格、高质量
5. location、time、atmosphere和prompt字段都使用中文
6. 提取场景的氛围描述atmosphere
【输出JSON格式】
{
"backgrounds": [
{
"location": "地点名称(中文)",
"time": "时间描述(中文)",
"atmosphere": "氛围描述(中文)",
"prompt": "一个电影感的动漫风格纯背景场景,展现[地点描述]在[时间]的环境。画面呈现[环境细节、建筑、物品、光线等,不包含人物]。风格:细节丰富,高质量,氛围光照。情绪:[环境情绪描述]。"
}
]
}
【示例】
正确示例(注意:不包含人物):
{
"backgrounds": [
{
"location": "维修店内部",
"time": "深夜",
"atmosphere": "昏暗、孤独、工业感",
"prompt": "一个电影感的动漫风格纯背景场景,展现凌乱的维修店内部在深夜的环境。昏暗的日光灯照射下,工作台上散落着各种扳手、螺丝刀和机械零件,墙上挂着油污斑斑的工具挂板和褪色海报,地面有油渍痕迹,角落堆放着废旧轮胎。风格:细节丰富,高质量,昏暗氛围。情绪:孤独、工业感。"
},
{
"location": "城市街道",
"time": "黄昏",
"atmosphere": "温暖、繁忙、生活气息",
"prompt": "一个电影感的动漫风格纯背景场景,展现繁华的城市街道在黄昏时分的环境。夕阳的余晖洒在街道的沥青路面上,两旁的商铺霓虹灯开始点亮,街边有自行车停靠架和公交站牌,远处高楼林立,天空呈现橙红色渐变。风格:细节丰富,高质量,温暖氛围。情绪:生活气息、繁忙。"
}
]
}
【错误示例(包含人物,禁止)】:
❌ "展现主角站在街道上的场景" - 包含人物
❌ "人们匆匆而过" - 包含人物
❌ "角色在房间里活动" - 包含人物
请严格按照JSON格式输出确保所有字段都使用中文。`, scriptContent)
response, err := client.GenerateText(prompt, "", ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
if err != nil {
s.log.Errorw("Failed to extract backgrounds with AI", "error", err)
return nil, fmt.Errorf("AI提取场景失败: %w", err)
}
s.log.Infow("AI backgrounds extraction response", "length", len(response))
// 解析JSON响应
var result struct {
Backgrounds []BackgroundInfo `json:"backgrounds"`
}
if err := utils.SafeParseAIJSON(response, &result); err != nil {
s.log.Errorw("Failed to parse AI response", "error", err, "response", response[:minInt(500, len(response))])
return nil, fmt.Errorf("解析AI响应失败: %w", err)
}
s.log.Infow("Extracted backgrounds from script",
"drama_id", dramaID,
"backgrounds_count", len(result.Backgrounds))
return result.Backgrounds, nil
}
// extractBackgroundsWithAI 使用AI智能分析场景并提取唯一背景
func (s *ImageGenerationService) extractBackgroundsWithAI(storyboards []models.Storyboard) ([]BackgroundInfo, error) {
if len(storyboards) == 0 {
return []BackgroundInfo{}, nil
}
// 构建场景列表文本使用SceneNumber而不是索引
var scenesText string
for _, storyboard := range storyboards {
location := ""
if storyboard.Location != nil {
location = *storyboard.Location
}
time := ""
if storyboard.Time != nil {
time = *storyboard.Time
}
action := ""
if storyboard.Action != nil {
action = *storyboard.Action
}
description := ""
if storyboard.Description != nil {
description = *storyboard.Description
}
scenesText += fmt.Sprintf("镜头%d:\n地点: %s\n时间: %s\n动作: %s\n描述: %s\n\n",
storyboard.StoryboardNumber, location, time, action, description)
}
// 构建AI提示词
prompt := fmt.Sprintf(`【任务】分析以下分镜头场景,提取出所有需要生成的唯一背景,并返回每个背景对应的场景编号。
【分镜头列表】
%s
【要求】
1. 合并相同或相似的场景背景(地点和时间相同或相近)
2. 为每个唯一背景生成**中文**图片生成提示词Prompt
3. Prompt要求
- **必须使用中文**,不能包含英文字符
- 详细描述场景、时间、氛围、风格
- 适合AI图片生成模型使用
- 风格统一为:电影感、细节丰富、动漫风格、高质量
4. **重要**必须返回使用该背景的场景编号数组scene_numbers
5. location、time和prompt字段都使用中文
6. 每个场景都必须分配到某个背景,确保所有场景编号都被包含
【输出JSON格式】
{
"backgrounds": [
{
"location": "地点名称(中文)",
"time": "时间描述(中文)",
"prompt": "一个电影感的动漫风格背景,展现[地点描述]在[时间]的场景。画面呈现[细节描述]。风格:细节丰富,高质量,氛围光照。情绪:[情绪描述]。",
"scene_numbers": [1, 2, 3]
}
]
}
【示例】
正确示例:
{
"backgrounds": [
{
"location": "维修店",
"time": "深夜",
"prompt": "一个电影感的动漫风格背景,展现凌乱的维修店内部在深夜的场景。昏暗的灯光下,工作台上散落着各种工具和零件,墙上挂着油污的海报。风格:细节丰富,高质量,昏暗氛围。情绪:孤独、工业感。",
"scene_numbers": [1, 5, 6, 10, 15]
},
{
"location": "城市全景",
"time": "深夜·酸雨",
"prompt": "一个电影感的动漫风格背景,展现沿海城市全景在深夜酸雨中的场景。霓虹灯在雨中模糊,高楼大厦笼罩在灰绿色的雨幕中,街道反射着五颜六色的光。风格:细节丰富,高质量,赛博朋克氛围。情绪:压抑、科幻、末世感。",
"scene_numbers": [2, 7]
}
]
}
请严格按照JSON格式输出确保
1. prompt字段使用中文
2. scene_numbers包含所有使用该背景的场景编号
3. 所有场景都被分配到某个背景`, scenesText)
// 调用AI服务
text, err := s.aiService.GenerateText(prompt, "")
if err != nil {
return nil, fmt.Errorf("AI analysis failed: %w", err)
}
// 解析AI返回的JSON
var result struct {
Scenes []struct {
Location string `json:"location"`
Time string `json:"time"`
Prompt string `json:"prompt"`
StoryboardNumber []int `json:"storyboard_number"`
} `json:"backgrounds"`
}
if err := utils.SafeParseAIJSON(text, &result); err != nil {
return nil, fmt.Errorf("failed to parse AI response: %w", err)
}
// 构建场景编号到场景ID的映射
storyboardNumberToID := make(map[int]uint)
for _, scene := range storyboards {
storyboardNumberToID[scene.StoryboardNumber] = scene.ID
}
// 转换为BackgroundInfo
var backgrounds []BackgroundInfo
for _, bg := range result.Scenes {
// 将场景编号转换为场景ID
var sceneIDs []uint
for _, storyboardNum := range bg.StoryboardNumber {
if storyboardID, ok := storyboardNumberToID[storyboardNum]; ok {
sceneIDs = append(sceneIDs, storyboardID)
}
}
backgrounds = append(backgrounds, BackgroundInfo{
Location: bg.Location,
Time: bg.Time,
Prompt: bg.Prompt,
StoryboardNumbers: bg.StoryboardNumber,
SceneIDs: sceneIDs,
StoryboardCount: len(sceneIDs),
})
}
s.log.Infow("AI extracted backgrounds",
"total_scenes", len(storyboards),
"extracted_backgrounds", len(backgrounds))
return backgrounds, nil
}
// extractUniqueBackgrounds 从分镜头中提取唯一背景代码逻辑作为AI提取的备份
func (s *ImageGenerationService) extractUniqueBackgrounds(scenes []models.Storyboard) []BackgroundInfo {
backgroundMap := make(map[string]*BackgroundInfo)
for _, scene := range scenes {
if scene.Location == nil || scene.Time == nil {
continue
}
// 使用 location + time 作为唯一标识
key := *scene.Location + "|" + *scene.Time
if bg, exists := backgroundMap[key]; exists {
// 背景已存在添加scene ID
bg.SceneIDs = append(bg.SceneIDs, scene.ID)
bg.StoryboardCount++
} else {
// 新背景 - 使用ImagePrompt构建背景提示词
prompt := ""
if scene.ImagePrompt != nil {
prompt = *scene.ImagePrompt
}
backgroundMap[key] = &BackgroundInfo{
Location: *scene.Location,
Time: *scene.Time,
Prompt: prompt,
SceneIDs: []uint{scene.ID},
StoryboardCount: 1,
}
}
}
// 转换为切片
var backgrounds []BackgroundInfo
for _, bg := range backgroundMap {
backgrounds = append(backgrounds, *bg)
}
return backgrounds
}

View File

@@ -0,0 +1,21 @@
package services
import (
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
type ResourceTransferService struct {
db *gorm.DB
log *logger.Logger
}
func NewResourceTransferService(db *gorm.DB, log *logger.Logger) *ResourceTransferService {
return &ResourceTransferService{
db: db,
log: log,
}
}
// ResourceTransferService 现在只保留基本结构MinIO相关功能已移除
// 如需资源转存功能,请使用本地存储

View File

@@ -0,0 +1,511 @@
package services
import (
"encoding/json"
"fmt"
"strconv"
"github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/ai"
"github.com/drama-generator/backend/pkg/logger"
"github.com/drama-generator/backend/pkg/utils"
"gorm.io/gorm"
)
type ScriptGenerationService struct {
db *gorm.DB
aiService *AIService
log *logger.Logger
}
func NewScriptGenerationService(db *gorm.DB, log *logger.Logger) *ScriptGenerationService {
return &ScriptGenerationService{
db: db,
aiService: NewAIService(db, log),
log: log,
}
}
type GenerateOutlineRequest struct {
DramaID string `json:"drama_id" binding:"required"`
Theme string `json:"theme" binding:"required,min=2,max=500"`
Genre string `json:"genre"`
Style string `json:"style"`
Length int `json:"length"`
Temperature float64 `json:"temperature"`
}
type GenerateCharactersRequest struct {
DramaID string `json:"drama_id" binding:"required"`
Outline string `json:"outline"`
Count int `json:"count"`
Temperature float64 `json:"temperature"`
}
type GenerateEpisodesRequest struct {
DramaID string `json:"drama_id" binding:"required"`
Outline string `json:"outline"`
EpisodeCount int `json:"episode_count" binding:"required,min=1,max=100"`
Temperature float64 `json:"temperature"`
}
type OutlineResult struct {
Title string `json:"title"`
Summary string `json:"summary"`
Genre string `json:"genre"`
Tags []string `json:"tags"`
Characters []CharacterOutline `json:"characters"`
Episodes []EpisodeOutline `json:"episodes"`
KeyScenes []string `json:"key_scenes"`
}
type CharacterOutline struct {
Name string `json:"name"`
Role string `json:"role"`
Description string `json:"description"`
Personality string `json:"personality"`
Appearance string `json:"appearance"`
}
type EpisodeOutline struct {
EpisodeNumber int `json:"episode_number"`
Title string `json:"title"`
Summary string `json:"summary"`
Scenes []string `json:"scenes"`
Duration int `json:"duration"`
}
func (s *ScriptGenerationService) GenerateOutline(req *GenerateOutlineRequest) (*OutlineResult, error) {
var drama models.Drama
if err := s.db.Where("id = ?", req.DramaID).First(&drama).Error; err != nil {
return nil, fmt.Errorf("drama not found")
}
systemPrompt := `你是专业短剧编剧。根据主题和剧集数量,创作完整的短剧大纲,规划好每一集的剧情走向。
要求:
1. 剧情紧凑,矛盾冲突强烈,节奏快
2. 必须规划好每一集的核心剧情
3. 每集有明确冲突和转折点,集与集之间有连贯性和悬念
**重要必须输出完整有效的JSON确保所有字段完整特别是episodes数组必须完整闭合**
JSON格式紧凑summary和episodes字段必须完整
{"title":"剧名","summary":"200-250字剧情概述包含故事背景、主要矛盾、核心冲突、完整走向","genre":"类型","tags":["标签1","标签2","标签3"],"episodes":[{"episode_number":1,"title":"标题","summary":"80字剧情概要"},{"episode_number":2,"title":"标题","summary":"80字剧情概要"}],"key_scenes":["场景1","场景2","场景3"]}
关键要求:
- summary控制在200-250字简洁清晰
- episodes必须生成用户要求的完整集数
- 每集summary控制在80字左右
- 确保JSON完整闭合不要截断
- 不要添加任何JSON外的文字说明`
userPrompt := fmt.Sprintf(`请为以下主题创作短剧大纲:
主题:%s`, req.Theme)
if req.Genre != "" {
userPrompt += fmt.Sprintf("\n类型偏好%s", req.Genre)
}
if req.Style != "" {
userPrompt += fmt.Sprintf("\n风格要求%s", req.Style)
}
length := req.Length
if length == 0 {
length = 5
}
userPrompt += fmt.Sprintf("\n剧集数量%d集", length)
userPrompt += fmt.Sprintf("\n\n**重要必须在episodes数组中规划完整的%d集剧情每集都要有明确的故事内容**", length)
temperature := req.Temperature
if temperature == 0 {
temperature = 0.8
}
// 调整token限制基础2000 + 每集约150 tokens包含80-100字概要
maxTokens := 2000 + (length * 150)
if maxTokens > 8000 {
maxTokens = 8000
}
s.log.Infow("Generating outline with episodes",
"episode_count", length,
"max_tokens", maxTokens)
text, err := s.aiService.GenerateText(
userPrompt,
systemPrompt,
ai.WithTemperature(temperature),
ai.WithMaxTokens(maxTokens),
)
if err != nil {
s.log.Errorw("Failed to generate outline", "error", err)
return nil, fmt.Errorf("生成失败: %w", err)
}
s.log.Infow("AI response received", "length", len(text), "preview", text[:minInt(200, len(text))])
var result OutlineResult
if err := utils.SafeParseAIJSON(text, &result); err != nil {
s.log.Errorw("Failed to parse outline JSON", "error", err, "raw_response", text[:minInt(500, len(text))])
return nil, fmt.Errorf("解析 AI 返回结果失败: %w", err)
}
// 将Tags转换为JSON格式存储
tagsJSON, err := json.Marshal(result.Tags)
if err != nil {
s.log.Errorw("Failed to marshal tags", "error", err)
tagsJSON = []byte("[]")
}
if err := s.db.Model(&drama).Updates(map[string]interface{}{
"title": result.Title,
"description": result.Summary,
"genre": result.Genre,
"tags": tagsJSON,
}).Error; err != nil {
s.log.Errorw("Failed to update drama", "error", err)
}
s.log.Infow("Outline generated", "drama_id", req.DramaID)
return &result, nil
}
func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequest) ([]models.Character, error) {
var drama models.Drama
if err := s.db.Where("id = ? ", req.DramaID).First(&drama).Error; err != nil {
return nil, fmt.Errorf("drama not found")
}
count := req.Count
if count == 0 {
count = 5
}
systemPrompt := `你是一个专业的角色分析师,擅长从剧本中提取和分析角色信息。
你的任务是根据提供的剧本内容,提取并整理剧中出现的所有角色的详细设定。
要求:
1. 仔细阅读剧本,识别所有出现的角色
2. 根据剧本中的对话、行为和描述,总结角色的性格特点
3. 提取角色在剧本中的关键信息:背景、动机、目标、关系等
4. 角色之间的关系必须基于剧本中的实际描述
5. 外貌描述必须极其详细如果剧本中有描述则使用如果没有则根据角色设定合理推断便于AI绘画生成角色形象
6. 优先提取主要角色和重要配角,次要角色可以简略
请严格按照以下 JSON 格式输出,不要添加任何其他文字:
{
"characters": [
{
"name": "角色名",
"role": "主角/重要配角/配角",
"description": "角色背景和简介200-300字包括出身背景、成长经历、核心动机、与其他角色的关系、在故事中的作用",
"personality": "性格特点详细描述100-150字包括主要性格特征、行为习惯、价值观、优点缺点、情绪表达方式、对待他人的态度等",
"appearance": "外貌描述极其详细150-200字必须包括确切年龄、精确身高、体型身材、肤色质感、发型发色发长、眼睛颜色形状、面部特征如眉毛、鼻子、嘴唇、着装风格、服装颜色材质、配饰细节、标志性特征、整体气质风格等描述要具体到可以直接用于AI绘画",
"voice_style": "说话风格和语气特点详细描述50-80字包括语速语调、用词习惯、口头禅、说话时的情绪特征等"
}
]
}
注意:
- 必须基于剧本内容提取角色,不要凭空创作
- 优先提取主要角色和重要配角,数量根据剧本实际情况确定
- description、personality、appearance、voice_style都必须详细描述字数要充足
- appearance外貌描述是重中之重必须极其详细具体要能让AI准确生成角色形象
- 如果剧本中角色信息不完整,可以根据角色设定合理补充,但要符合剧本整体风格`
outlineText := req.Outline
if outlineText == "" {
outlineText = fmt.Sprintf("剧名:%s\n简介%s\n类型%s", drama.Title, drama.Description, drama.Genre)
}
userPrompt := fmt.Sprintf(`剧本内容:
%s
请从剧本中提取并整理最多 %d 个主要角色的详细设定。`, outlineText, count)
temperature := req.Temperature
if temperature == 0 {
temperature = 0.7
}
text, err := s.aiService.GenerateText(
userPrompt,
systemPrompt,
ai.WithTemperature(temperature),
ai.WithMaxTokens(3000),
)
if err != nil {
s.log.Errorw("Failed to generate characters", "error", err)
return nil, fmt.Errorf("生成失败: %w", err)
}
s.log.Infow("AI response received", "length", len(text), "preview", text[:minInt(200, len(text))])
var result struct {
Characters []struct {
Name string `json:"name"`
Role string `json:"role"`
Description string `json:"description"`
Personality string `json:"personality"`
Appearance string `json:"appearance"`
VoiceStyle string `json:"voice_style"`
} `json:"characters"`
}
if err := utils.SafeParseAIJSON(text, &result); err != nil {
s.log.Errorw("Failed to parse characters JSON", "error", err, "raw_response", text[:minInt(500, len(text))])
return nil, fmt.Errorf("解析 AI 返回结果失败: %w", err)
}
var characters []models.Character
for _, char := range result.Characters {
// 检查角色是否已存在
var existingChar models.Character
err := s.db.Where("drama_id = ? AND name = ?", req.DramaID, char.Name).First(&existingChar).Error
if err == nil {
// 角色已存在,直接使用已存在的角色,不覆盖
s.log.Infow("Character already exists, skipping", "drama_id", req.DramaID, "name", char.Name)
characters = append(characters, existingChar)
continue
}
// 角色不存在,创建新角色
dramaID, _ := strconv.ParseUint(req.DramaID, 10, 32)
character := models.Character{
DramaID: uint(dramaID),
Name: char.Name,
Role: &char.Role,
Description: &char.Description,
Personality: &char.Personality,
Appearance: &char.Appearance,
VoiceStyle: &char.VoiceStyle,
}
if err := s.db.Create(&character).Error; err != nil {
s.log.Errorw("Failed to create character", "error", err)
continue
}
characters = append(characters, character)
}
s.log.Infow("Characters generated", "drama_id", req.DramaID, "total_count", len(characters), "new_count", len(characters))
return characters, nil
}
func (s *ScriptGenerationService) GenerateEpisodes(req *GenerateEpisodesRequest) ([]models.Episode, error) {
var drama models.Drama
if err := s.db.Where("id = ? ", req.DramaID).First(&drama).Error; err != nil {
return nil, fmt.Errorf("drama not found")
}
// 获取角色信息
var characters []models.Character
s.db.Where("drama_id = ?", req.DramaID).Find(&characters)
var characterList string
if len(characters) > 0 {
characterList = "\n角色设定\n"
for _, char := range characters {
characterList += fmt.Sprintf("- %s", char.Name)
if char.Role != nil {
characterList += fmt.Sprintf("%s", *char.Role)
}
if char.Description != nil {
characterList += fmt.Sprintf("%s", *char.Description)
}
if char.Personality != nil {
characterList += fmt.Sprintf(" | 性格:%s", *char.Personality)
}
characterList += "\n"
}
} else {
characterList = "\n注意尚未设定角色请根据大纲创作合理的角色出场\n"
}
systemPrompt := `你是一个专业的短剧编剧。你擅长根据分集规划创作详细的剧情内容。
你的任务是根据大纲中的分集规划将每一集的概要扩展为详细的剧情叙述。每集约180秒3分钟需要充实的内容。
工作流程:
1. 大纲中已提供每集的剧情规划80-100字概要
2. 你需要将每集概要扩展为400-500字的详细剧情叙述
3. 严格按照分集规划的数量和走向展开,不能遗漏任何一集
详细要求:
1. script_content用400-500字详细叙述包括
- 具体场景和环境描写
- 角色的行动、对话要点、情绪变化
- 冲突的产生过程和激化细节
- 关键情节点和转折
- 为下一集埋下的伏笔
2. 每集有明确的冲突和转折点
3. 集与集之间有连贯性和悬念
4. 充分展现角色性格和关系演变
5. 内容详实足以支撑180秒时长
JSON格式紧凑
{"episodes":[{"episode_number":1,"title":"标题","description":"简短梗概","script_content":"400-500字详细剧情叙述","duration":210}]}
格式说明:
1. script_content为叙述文不是场景对话格式
2. 每集包含开场铺垫、冲突发展、高潮转折、结局悬念
3. duration根据剧情复杂度设置在150-300秒
关键要求:
- 大纲规划了几集就必须生成几集
- 严格按照分集规划的故事线展开
- 每一集都要有完整的400-500字详细内容
- 绝对不能遗漏任何一集`
outlineText := req.Outline
if outlineText == "" {
outlineText = fmt.Sprintf("剧名:%s\n简介%s\n类型%s", drama.Title, drama.Description, drama.Genre)
}
userPrompt := fmt.Sprintf(`剧本大纲:
%s
%s
请基于以上大纲和角色,创作 %d 集的详细剧本。
**重要要求:**
- 必须生成完整的 %d 集从第1集到第%d集不能遗漏
- 每集约3-5分钟150-300秒
- 每集的duration字段要根据剧本内容长度合理设置不要都设置为同一个值
- 返回的JSON中episodes数组必须包含 %d 个元素`, outlineText, characterList, req.EpisodeCount, req.EpisodeCount, req.EpisodeCount, req.EpisodeCount)
temperature := req.Temperature
if temperature == 0 {
temperature = 0.7
}
// 根据剧集数量调整token限制
// 模型支持128k上下文每集400-500字约需800-1000 tokens包含JSON结构
baseTokens := 3000 // 基础(系统提示+角色列表+大纲)
perEpisodeTokens := 900 // 每集约900 tokens支持400-500字详细内容
maxTokens := baseTokens + (req.EpisodeCount * perEpisodeTokens)
// 128k上下文可以设置较大的token限制
// 10集约12000 tokens20集约21000 tokens都在安全范围内
if maxTokens > 32000 {
maxTokens = 32000 // 保守限制在32k留足够空间
}
s.log.Infow("Generating episodes with token limit",
"episode_count", req.EpisodeCount,
"max_tokens", maxTokens,
"estimated_per_episode", perEpisodeTokens)
text, err := s.aiService.GenerateText(
userPrompt,
systemPrompt,
ai.WithTemperature(0.8),
ai.WithMaxTokens(maxTokens),
)
if err != nil {
s.log.Errorw("Failed to generate episodes", "error", err)
return nil, fmt.Errorf("生成失败: %w", err)
}
s.log.Infow("AI response received", "length", len(text), "preview", text[:minInt(200, len(text))])
var result struct {
Episodes []struct {
EpisodeNumber int `json:"episode_number"`
Title string `json:"title"`
Description string `json:"description"`
ScriptContent string `json:"script_content"`
Duration int `json:"duration"`
} `json:"episodes"`
}
if err := utils.SafeParseAIJSON(text, &result); err != nil {
s.log.Errorw("Failed to parse episodes JSON", "error", err, "raw_response", text[:minInt(500, len(text))])
return nil, fmt.Errorf("解析 AI 返回结果失败: %w", err)
}
// 检查生成的集数是否符合要求
if len(result.Episodes) < req.EpisodeCount {
s.log.Warnw("AI generated fewer episodes than requested",
"requested", req.EpisodeCount,
"generated", len(result.Episodes))
}
// 记录每集的详细信息
for i, ep := range result.Episodes {
s.log.Infow("Episode parsed from AI",
"index", i,
"episode_number", ep.EpisodeNumber,
"title", ep.Title,
"description_length", len(ep.Description),
"script_content_length", len(ep.ScriptContent),
"duration", ep.Duration)
}
var episodes []models.Episode
for _, ep := range result.Episodes {
duration := ep.Duration
if duration == 0 {
// AI未返回时长时使用默认值
duration = 180
s.log.Warnw("Episode duration not provided by AI, using default",
"episode_number", ep.EpisodeNumber,
"default_duration", 180)
} else {
s.log.Infow("Episode duration from AI",
"episode_number", ep.EpisodeNumber,
"duration", duration)
}
// 记录即将保存的数据
s.log.Infow("Creating episode in database",
"episode_number", ep.EpisodeNumber,
"title", ep.Title,
"script_content_length", len(ep.ScriptContent),
"script_content_empty", ep.ScriptContent == "")
dramaID, err := strconv.ParseUint(req.DramaID, 10, 32)
if err != nil {
return nil, fmt.Errorf("invalid drama ID")
}
episode := models.Episode{
DramaID: uint(dramaID),
EpisodeNum: ep.EpisodeNumber,
Title: ep.Title,
Description: &ep.Description,
ScriptContent: &ep.ScriptContent,
Duration: duration,
Status: "draft",
}
if err := s.db.Create(&episode).Error; err != nil {
s.log.Errorw("Failed to create episode", "error", err)
continue
}
episodes = append(episodes, episode)
}
s.log.Infow("Episodes generated", "drama_id", req.DramaID, "count", len(episodes))
return episodes, nil
}
// GenerateScenesForEpisode 已废弃,使用 StoryboardService.GenerateStoryboard 替代
// ParseScript 已废弃,使用 GenerateCharacters 替代
// minInt 返回两个整数中较小的一个
func minInt(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,395 @@
package services
import (
"encoding/json"
"fmt"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
type StoryboardCompositionService struct {
db *gorm.DB
log *logger.Logger
imageGen *ImageGenerationService
}
func NewStoryboardCompositionService(db *gorm.DB, log *logger.Logger, imageGen *ImageGenerationService) *StoryboardCompositionService {
return &StoryboardCompositionService{
db: db,
log: log,
imageGen: imageGen,
}
}
type SceneCharacterInfo struct {
ID uint `json:"id"`
Name string `json:"name"`
ImageURL *string `json:"image_url,omitempty"`
}
type SceneBackgroundInfo struct {
ID uint `json:"id"`
Location string `json:"location"`
Time string `json:"time"`
ImageURL *string `json:"image_url,omitempty"`
Status string `json:"status"`
}
type SceneCompositionInfo struct {
ID uint `json:"id"`
StoryboardNumber int `json:"storyboard_number"`
Title *string `json:"title"`
Description *string `json:"description"`
Location *string `json:"location"`
Time *string `json:"time"`
Duration int `json:"duration"`
Dialogue *string `json:"dialogue"`
Action *string `json:"action"`
Atmosphere *string `json:"atmosphere"`
ImagePrompt *string `json:"image_prompt,omitempty"`
VideoPrompt *string `json:"video_prompt,omitempty"`
Characters []SceneCharacterInfo `json:"characters"`
Background *SceneBackgroundInfo `json:"background"`
SceneID *uint `json:"scene_id"`
ComposedImage *string `json:"composed_image,omitempty"`
VideoURL *string `json:"video_url,omitempty"`
ImageGenerationID *uint `json:"image_generation_id,omitempty"`
ImageGenerationStatus *string `json:"image_generation_status,omitempty"`
VideoGenerationID *uint `json:"video_generation_id,omitempty"`
VideoGenerationStatus *string `json:"video_generation_status,omitempty"`
}
func (s *StoryboardCompositionService) GetScenesForEpisode(episodeID string) ([]SceneCompositionInfo, error) {
// 验证权限
var episode models.Episode
err := s.db.Preload("Drama").Where("id = ?", episodeID).First(&episode).Error
if err != nil {
s.log.Errorw("Episode not found", "episode_id", episodeID, "error", err)
return nil, fmt.Errorf("episode not found")
}
s.log.Infow("GetScenesForEpisode auth check",
"episode_id", episodeID,
"drama_id", episode.DramaID)
// 获取分镜列表
var storyboards []models.Storyboard
if err := s.db.Where("episode_id = ?", episodeID).
Preload("Characters").
Order("storyboard_number ASC").
Find(&storyboards).Error; err != nil {
return nil, fmt.Errorf("failed to load storyboards: %w", err)
}
// 获取所有角色(用于匹配角色信息)
var characters []models.Character
if err := s.db.Where("drama_id = ?", episode.DramaID).Find(&characters).Error; err != nil {
s.log.Warnw("Failed to load characters", "error", err)
}
// 创建角色ID到角色信息的映射
charIDToInfo := make(map[uint]*models.Character)
for i := range characters {
charIDToInfo[characters[i].ID] = &characters[i]
}
// 获取所有场景ID
var sceneIDs []uint
for _, storyboard := range storyboards {
if storyboard.SceneID != nil {
sceneIDs = append(sceneIDs, *storyboard.SceneID)
}
}
// 批量获取场景信息
var scenes []models.Scene
sceneMap := make(map[uint]*models.Scene)
if len(sceneIDs) > 0 {
if err := s.db.Where("id IN ?", sceneIDs).Find(&scenes).Error; err == nil {
for i := range scenes {
sceneMap[scenes[i].ID] = &scenes[i]
}
}
}
// 获取分镜的合成图片(从 image_generations 表)
storyboardIDs := make([]uint, len(storyboards))
for i, storyboard := range storyboards {
storyboardIDs[i] = storyboard.ID
}
imageGenMap := make(map[uint]string) // storyboard_id -> image_url
imageGenTaskMap := make(map[uint]*models.ImageGeneration) // storyboard_id -> processing task
if len(storyboardIDs) > 0 {
var imageGens []models.ImageGeneration
// 查询已完成的图片生成记录,每个镜头只取最新的一条
if err := s.db.Where("storyboard_id IN ? AND status = ?", storyboardIDs, models.ImageStatusCompleted).
Order("created_at DESC").
Find(&imageGens).Error; err == nil {
// 为每个镜头保留最新的一条记录
for _, ig := range imageGens {
if ig.StoryboardID != nil {
if _, exists := imageGenMap[*ig.StoryboardID]; !exists {
if ig.ImageURL != nil {
imageGenMap[*ig.StoryboardID] = *ig.ImageURL
}
}
}
}
}
// 查询进行中的图片生成任务
var processingImageGens []models.ImageGeneration
if err := s.db.Where("storyboard_id IN ? AND status = ?", storyboardIDs, models.ImageStatusProcessing).
Order("created_at DESC").
Find(&processingImageGens).Error; err == nil {
for _, ig := range processingImageGens {
if ig.StoryboardID != nil {
if _, exists := imageGenTaskMap[*ig.StoryboardID]; !exists {
igCopy := ig
imageGenTaskMap[*ig.StoryboardID] = &igCopy
}
}
}
}
}
// 批量查询进行中的视频生成任务
videoGenTaskMap := make(map[uint]*models.VideoGeneration) // storyboard_id -> processing task
if len(storyboardIDs) > 0 {
var processingVideoGens []models.VideoGeneration
if err := s.db.Where("scene_id IN ? AND status = ?", storyboardIDs, models.VideoStatusProcessing).
Order("created_at DESC").
Find(&processingVideoGens).Error; err == nil {
for _, vg := range processingVideoGens {
if vg.StoryboardID != nil {
if _, exists := videoGenTaskMap[*vg.StoryboardID]; !exists {
vgCopy := vg
videoGenTaskMap[*vg.StoryboardID] = &vgCopy
}
}
}
}
}
// 构建返回结果
var result []SceneCompositionInfo
for _, storyboard := range storyboards {
storyboardInfo := SceneCompositionInfo{
ID: storyboard.ID,
StoryboardNumber: storyboard.StoryboardNumber,
Title: storyboard.Title,
Description: storyboard.Description,
Location: storyboard.Location,
Time: storyboard.Time,
Duration: storyboard.Duration,
Action: storyboard.Action,
Dialogue: storyboard.Dialogue,
Atmosphere: storyboard.Atmosphere,
ImagePrompt: storyboard.ImagePrompt,
VideoPrompt: storyboard.VideoPrompt,
SceneID: storyboard.SceneID,
}
// 直接使用关联的角色信息
if len(storyboard.Characters) > 0 {
for _, char := range storyboard.Characters {
storyboardChar := SceneCharacterInfo{
ID: char.ID,
Name: char.Name,
ImageURL: char.ImageURL,
}
storyboardInfo.Characters = append(storyboardInfo.Characters, storyboardChar)
}
}
// 添加场景信息
if storyboard.SceneID != nil {
if scene, ok := sceneMap[*storyboard.SceneID]; ok {
storyboardInfo.Background = &SceneBackgroundInfo{
ID: scene.ID,
Location: scene.Location,
Time: scene.Time,
ImageURL: scene.ImageURL,
Status: scene.Status,
}
}
}
// 添加合成图片
if imageURL, ok := imageGenMap[storyboard.ID]; ok {
storyboardInfo.ComposedImage = &imageURL
}
// 添加视频URL
if storyboard.VideoURL != nil {
storyboardInfo.VideoURL = storyboard.VideoURL
}
// 添加进行中的图片生成任务信息
if imageTask, ok := imageGenTaskMap[storyboard.ID]; ok {
storyboardInfo.ImageGenerationID = &imageTask.ID
statusStr := string(imageTask.Status)
storyboardInfo.ImageGenerationStatus = &statusStr
}
// 添加进行中的视频生成任务信息
if videoTask, ok := videoGenTaskMap[storyboard.ID]; ok {
storyboardInfo.VideoGenerationID = &videoTask.ID
statusStr := string(videoTask.Status)
storyboardInfo.VideoGenerationStatus = &statusStr
}
result = append(result, storyboardInfo)
}
return result, nil
}
type UpdateSceneRequest struct {
SceneID *uint `json:"scene_id"`
Characters []uint `json:"characters"` // 改为存储角色ID数组
Location *string `json:"location"`
Time *string `json:"time"`
Action *string `json:"action"`
Dialogue *string `json:"dialogue"`
Description *string `json:"description"`
Duration *int `json:"duration"`
ImagePrompt *string `json:"image_prompt"`
VideoPrompt *string `json:"video_prompt"`
}
func (s *StoryboardCompositionService) UpdateScene(sceneID string, req *UpdateSceneRequest) error {
// 获取分镜并验证权限
var storyboard models.Storyboard
err := s.db.Preload("Episode.Drama").Where("id = ?", sceneID).First(&storyboard).Error
if err != nil {
return fmt.Errorf("scene not found")
}
// 构建更新数据
updates := make(map[string]interface{})
// 更新背景ID
if req.SceneID != nil {
updates["scene_id"] = req.SceneID
}
// 更新角色列表直接存储ID数组
if req.Characters != nil {
charactersJSON, err := json.Marshal(req.Characters)
if err != nil {
return fmt.Errorf("failed to serialize characters: %w", err)
}
updates["characters"] = charactersJSON
}
// 更新场景信息字段
if req.Location != nil {
updates["location"] = req.Location
}
if req.Time != nil {
updates["time"] = req.Time
}
if req.Action != nil {
updates["action"] = req.Action
}
if req.Dialogue != nil {
updates["dialogue"] = req.Dialogue
}
if req.Description != nil {
updates["description"] = req.Description
}
if req.Duration != nil {
updates["duration"] = *req.Duration
}
if req.ImagePrompt != nil {
updates["image_prompt"] = req.ImagePrompt
}
if req.VideoPrompt != nil {
updates["video_prompt"] = req.VideoPrompt
}
// 执行更新
if len(updates) > 0 {
if err := s.db.Model(&models.Storyboard{}).Where("id = ?", sceneID).Updates(updates).Error; err != nil {
return fmt.Errorf("failed to update scene: %w", err)
}
}
s.log.Infow("Scene updated", "scene_id", sceneID, "updates", updates)
return nil
}
type GenerateSceneImageRequest struct {
SceneID uint `json:"scene_id"`
Prompt string `json:"prompt"`
Model string `json:"model"`
}
func (s *StoryboardCompositionService) GenerateSceneImage(req *GenerateSceneImageRequest) (*models.ImageGeneration, error) {
// 获取场景并验证权限
var scene models.Scene
err := s.db.Where("id = ?", req.SceneID).First(&scene).Error
if err != nil {
return nil, fmt.Errorf("scene not found")
}
// 验证权限通过DramaID查询Drama
var drama models.Drama
if err := s.db.Where("id = ? ", scene.DramaID).First(&drama).Error; err != nil {
return nil, fmt.Errorf("unauthorized")
}
// 构建场景图片生成提示词
prompt := req.Prompt
if prompt == "" {
// 使用场景的Prompt字段
prompt = scene.Prompt
if prompt == "" {
// 如果Prompt为空使用Location和Time构建
prompt = fmt.Sprintf("%s场景%s", scene.Location, scene.Time)
}
s.log.Infow("Using scene prompt", "scene_id", req.SceneID, "prompt", prompt)
}
// 使用imageGen服务直接生成
if s.imageGen != nil {
genReq := &GenerateImageRequest{
SceneID: &req.SceneID,
DramaID: fmt.Sprintf("%d", scene.DramaID),
ImageType: string(models.ImageTypeScene),
Prompt: prompt,
Model: req.Model, // 使用用户指定的模型
Size: "2560x1440", // 3,686,400像素满足doubao模型最低要求16:9比例
Quality: "standard",
}
imageGen, err := s.imageGen.GenerateImage(genReq)
if err != nil {
return nil, fmt.Errorf("failed to generate image: %w", err)
}
// 更新场景的image_url
if imageGen.ImageURL != nil {
scene.ImageURL = imageGen.ImageURL
scene.Status = "generated"
if err := s.db.Save(&scene).Error; err != nil {
s.log.Errorw("Failed to update scene image url", "error", err)
}
}
s.log.Infow("Scene image generation created", "scene_id", req.SceneID, "image_gen_id", imageGen.ID)
return imageGen, nil
}
return nil, fmt.Errorf("image generation service not available")
}
func getStringValue(s *string) string {
if s != nil {
return *s
}
return ""
}

View File

@@ -0,0 +1,741 @@
package services
import (
"strconv"
"fmt"
"strings"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/logger"
"github.com/drama-generator/backend/pkg/utils"
"gorm.io/gorm"
)
type StoryboardService struct {
db *gorm.DB
aiService *AIService
log *logger.Logger
}
func NewStoryboardService(db *gorm.DB, log *logger.Logger) *StoryboardService {
return &StoryboardService{
db: db,
aiService: NewAIService(db, log),
log: log,
}
}
type Storyboard struct {
ShotNumber int `json:"shot_number"`
Title string `json:"title"` // 镜头标题
ShotType string `json:"shot_type"` // 景别
Angle string `json:"angle"` // 镜头角度
Time string `json:"time"` // 时间
Location string `json:"location"` // 地点
SceneID *uint `json:"scene_id"` // 背景IDAI直接返回可为null
Movement string `json:"movement"` // 运镜
Action string `json:"action"` // 动作
Dialogue string `json:"dialogue"` // 对话/独白
Result string `json:"result"` // 画面结果
Atmosphere string `json:"atmosphere"` // 环境氛围
Emotion string `json:"emotion"` // 情绪
Duration int `json:"duration"` // 时长(秒)
BgmPrompt string `json:"bgm_prompt"` // 配乐提示词
SoundEffect string `json:"sound_effect"` // 音效描述
Characters []uint `json:"characters"` // 涉及的角色ID列表
IsPrimary bool `json:"is_primary"` // 是否主镜
}
type GenerateStoryboardResult struct {
Storyboards []Storyboard `json:"storyboards"`
Total int `json:"total"`
}
func (s *StoryboardService) GenerateStoryboard(episodeID string) (*GenerateStoryboardResult, error) {
// 从数据库获取剧集信息
var episode struct {
ID string
ScriptContent *string
Description *string
DramaID string
}
err := s.db.Table("episodes").
Select("episodes.id, episodes.script_content, episodes.description, episodes.drama_id").
Joins("INNER JOIN dramas ON dramas.id = episodes.drama_id").
Where("episodes.id = ?", episodeID).
First(&episode).Error
if err != nil {
return nil, fmt.Errorf("剧集不存在或无权限访问")
}
// 获取剧本内容
var scriptContent string
if episode.ScriptContent != nil && *episode.ScriptContent != "" {
scriptContent = *episode.ScriptContent
} else if episode.Description != nil && *episode.Description != "" {
scriptContent = *episode.Description
} else {
return nil, fmt.Errorf("剧本内容为空,请先生成剧集内容")
}
// 获取该剧本的所有角色
var characters []models.Character
if err := s.db.Where("drama_id = ?", episode.DramaID).Order("name ASC").Find(&characters).Error; err != nil {
return nil, fmt.Errorf("获取角色列表失败: %w", err)
}
// 构建角色列表字符串包含ID和名称
characterList := "无角色"
if len(characters) > 0 {
var charInfoList []string
for _, char := range characters {
charInfoList = append(charInfoList, fmt.Sprintf(`{"id": %d, "name": "%s"}`, char.ID, char.Name))
}
characterList = fmt.Sprintf("[%s]", strings.Join(charInfoList, ", "))
}
// 获取该项目已提取的场景列表(项目级)
var scenes []models.Scene
if err := s.db.Where("drama_id = ?", episode.DramaID).Order("location ASC, time ASC").Find(&scenes).Error; err != nil {
s.log.Warnw("Failed to get scenes", "error", err)
}
// 构建场景列表字符串包含ID、地点、时间
sceneList := "无场景"
if len(scenes) > 0 {
var sceneInfoList []string
for _, bg := range scenes {
sceneInfoList = append(sceneInfoList, fmt.Sprintf(`{"id": %d, "location": "%s", "time": "%s"}`, bg.ID, bg.Location, bg.Time))
}
sceneList = fmt.Sprintf("[%s]", strings.Join(sceneInfoList, ", "))
}
s.log.Infow("Generating storyboard",
"episode_id", episodeID,
"drama_id", episode.DramaID,
"script_length", len(scriptContent),
"character_count", len(characters),
"characters", characterList,
"scene_count", len(scenes),
"scenes", sceneList)
// 构建分镜头生成提示词
prompt := fmt.Sprintf(`【角色】你是一位资深影视分镜师,精通罗伯特·麦基的镜头拆解理论,擅长构建情绪节奏。
【任务】将小说剧本按**独立动作单元**拆解为分镜头方案。
【本剧可用角色列表】
%s
**重要**在characters字段中只能使用上述角色列表中的角色ID数字不得自创角色或使用其他ID。
【本剧已提取的场景背景列表】
%s
**重要**在scene_id字段中必须从上述背景列表中选择最匹配的背景ID数字。如果没有合适的背景则填null。
【剧本原文】
%s
【分镜要素】每个镜头聚焦单一动作,描述要详尽具体:
1. **镜头标题(title)**用3-5个字概括该镜头的核心内容或情绪
- 例如:"噩梦惊醒"、"对视沉思"、"逃离现场"、"意外发现"
2. **时间**[清晨/午后/深夜/具体时分+详细光线描述]
- 例如:"深夜22:30·月光从破窗斜射入室内形成明暗分界"
3. **地点**[场景完整描述+空间布局+环境细节]
- 例如:"废弃码头仓库·锈蚀货架林立,地面积水反射微弱灯光,墙角堆放腐朽木箱"
4. **镜头设计**
- **景别(shot_type)**[远景/全景/中景/近景/特写]
- **镜头角度(angle)**[平视/仰视/俯视/侧面/背面]
- **运镜方式(movement)**[固定镜头/推镜/拉镜/摇镜/跟镜/移镜]
5. **人物行为****详细动作描述**,包含[谁+具体怎么做+肢体细节+表情状态]
- 例如:"陈峥弯腰用撬棍撬动保险箱门,手臂青筋暴起,眉头紧锁,汗水滑落脸颊"
6. **对话/独白**:提取该镜头中的完整对话或独白内容(如无对话则为空字符串)
7. **画面结果**:动作的即时后果+视觉细节+氛围变化
- 例如:"保险箱门弹开发出金属碰撞声,扬起灰尘在光束中飘散,箱内空无一物只有陈旧报纸,陈峥表情从期待转为失望"
8. **环境氛围**:光线质感+色调+声音环境+整体氛围
- 例如:"昏暗冷色调,只有手电筒光束晃动,远处传来海浪拍打声,压抑沉闷"
9. **配乐提示(bgm_prompt)**:描述该镜头配乐的氛围、节奏、情绪(如无特殊要求则为空字符串)
- 例如:"低沉紧张的弦乐,节奏缓慢,营造压抑氛围"
10. **音效描述(sound_effect)**:描述该镜头的关键音效(如无特殊音效则为空字符串)
- 例如:"金属碰撞声、脚步声、海浪拍打声"
11. **观众情绪**[情绪类型][强度:↑↑↑/↑↑/↑/→/↓] + [落点:悬置/释放/反转]
【输出格式】请以JSON格式输出每个镜头包含以下字段**所有描述性字段都要详细完整**
{
"storyboards": [
{
"shot_number": 1,
"title": "噩梦惊醒",
"shot_type": "全景",
"angle": "俯视45度角",
"time": "深夜22:30·月光从破窗斜射入仓库在地面积水中形成银白色反光墙角昏暗不清",
"location": "废弃码头仓库·锈蚀货架林立,地面积水反射微弱灯光,墙角堆放腐朽木箱和渔网,空气中弥漫潮湿霉味",
"scene_id": 1,
"movement": "固定镜头",
"action": "陈峥弯腰双手握住撬棍用力撬动保险箱门,手臂青筋暴起,眉头紧锁,汗水从额头滑落脸颊,呼吸急促",
"dialogue": "(独白)这么多年了,里面到底藏着什么秘密?",
"result": "保险箱门突然弹开发出刺耳金属声,扬起灰尘在手电筒光束中飘散,箱内空无一物只有几张发黄的旧报纸,陈峥表情从期待转为震惊和失望,瞳孔放大",
"atmosphere": "昏暗冷色调·青灰色为主,只有手电筒光束在黑暗中晃动,远处传来海浪拍打码头的沉闷声,整体氛围压抑沉重",
"emotion": "好奇感↑↑转失望↓(情绪反转)",
"duration": 9,
"bgm_prompt": "低沉紧张的弦乐,节奏缓慢,营造压抑悬疑氛围",
"sound_effect": "金属碰撞声、灰尘飘散声、海浪拍打声",
"characters": [159],
"is_primary": true
},
{
"shot_number": 2,
"title": "对视沉思",
"shot_type": "近景",
"angle": "平视",
"time": "深夜22:31·仓库内光线昏暗只有手电筒光从侧面照亮两人脸部轮廓",
"location": "废弃码头仓库·保险箱旁,背景是模糊的货架剪影",
"scene_id": 1,
"movement": "推镜",
"action": "陈峥缓缓转身,目光与身后的李芳对视,李芳手握手电筒,光束在两人之间晃动,眼神中透露疑惑和警惕",
"dialogue": "陈峥:\"我们被耍了,这里根本没有我们要找的东西。\" 李芳:\"现在怎么办?我们的时间不多了。\"",
"result": "两人站在昏暗中陷入沉思,手电筒光束照在地面形成圆形光斑,背景传来微弱的金属摩擦声,气氛紧张凝重",
"atmosphere": "低调光线·暗部占画面70%,侧面硬光勾勒人物轮廓,冷暖光对比强烈,海风吹过产生呼啸声,营造紧迫感",
"emotion": "紧张感↑↑·警惕↑↑(悬置)",
"duration": 7,
"bgm_prompt": "紧张感逐渐升级的音效,低频持续音",
"sound_effect": "呼吸声、金属摩擦声、海风呼啸声",
"characters": [159, 160],
"is_primary": true
}
]
}
**dialogue字段说明**
- 如果有对话,格式为:角色名:\"台词内容\"
- 多人对话用空格分隔角色A\"...\" 角色B\"...\"
- 独白格式为:(独白)内容
- 旁白格式为:(旁白)内容
- 无对话时填写空字符串:""
- **对话内容必须从原剧本中提取,保持原汁原味**
**角色和背景要求**
- characters字段必须包含该镜头中出现的所有角色ID数字数组格式
- 只提取实际出现的角色ID不出现角色则为空数组[]
- **角色ID必须严格使用【本剧可用角色列表】中的id字段数字不得使用其他ID或自创角色**
- 例如:如果镜头中出现李明(id:159)和王芳(id:160)则characters字段应为[159, 160]
- scene_id字段必须从【本剧已提取的场景背景列表】中选择最匹配的背景ID数字
- 如果列表中没有合适的背景则scene_id填null
- 例如:如果镜头发生在"城市公寓卧室·凌晨"应选择id为1的场景背景
**duration时长估算规则**
- **所有镜头时长必须在4-12秒范围内**,确保节奏合理流畅
- **综合估算原则**:时长由对话内容、动作复杂度、情绪节奏三方面综合决定
**估算步骤**
1. **基础时长**(从场景内容判断):
- 纯对话场景无明显动作基础4秒
- 纯动作场景无对话基础5秒
- 对话+动作混合场景基础6秒
2. **对话调整**(根据台词字数增加时长):
- 无对话:+0秒
- 短对话1-20字+1-2秒
- 中等对话21-50字+2-4秒
- 长对话51字以上+4-6秒
3. **动作调整**(根据动作复杂度增加时长):
- 无动作/静态:+0秒
- 简单动作(表情、转身、拿物品):+0-1秒
- 一般动作(走动、开门、坐下):+1-2秒
- 复杂动作(打斗、追逐、大幅度移动):+2-4秒
- 环境展示(全景扫描、氛围营造):+2-5秒
4. **最终时长** = 基础时长 + 对话调整 + 动作调整确保结果在4-12秒范围内
**示例**
- "陈峥转身离开"简单动作无对话5 + 0 + 1 = 6秒
- "李芳:\"你要去哪里?\""短对话无动作4 + 2 + 0 = 6秒
- "陈峥推开房门,李芳:\"终于找到你了,这些年你去哪了?\""(一般动作+中等对话6 + 3 + 2 = 11秒
- "两人在雨中激烈搏斗,陈峥:\"住手!\""(复杂动作+短对话6 + 2 + 4 = 12秒
**重要**:准确估算每个镜头时长,所有分镜时长之和将作为剧集总时长
**特别要求**
- **【极其重要】必须100%%完整拆解整个剧本,不得省略、跳过、压缩任何剧情内容**
- **从剧本第一个字到最后一个字,逐句逐段转换为分镜**
- **每个对话、每个动作、每个场景转换都必须有对应的分镜**
- 剧本越长分镜数量越多短剧本15-30个中等剧本30-60个长剧本60-100个甚至更多
- **宁可分镜多,也不要遗漏剧情**:一个长场景可拆分为多个连续分镜
- 每个镜头只描述一个主要动作
- 区分主镜is_primary: true和链接镜is_primary: false
- 确保情绪节奏有变化
- **duration字段至关重要**:准确估算每个镜头时长,这将用于计算整集时长
- 严格按照JSON格式输出
**【禁止行为】**
- ❌ 禁止用一个镜头概括多个场景
- ❌ 禁止跳过任何对话或独白
- ❌ 禁止省略剧情发展过程
- ❌ 禁止合并本应分开的镜头
- ✅ 正确做法:剧本有多少内容,就拆解出对应数量的分镜,确保观众看完所有分镜能完整了解剧情
**【关键】场景描述详细度要求**(这些描述将直接用于视频生成模型):
1. **时间(time)字段**必须包含≥15字的详细描述
- ✓ 好例子:"深夜22:30·月光从破窗斜射入仓库在地面积水中形成银白色反光墙角昏暗不清"
- ✗ 差例子:"深夜"
2. **地点(location)字段**必须包含≥20字的详细场景描述
- ✓ 好例子:"废弃码头仓库·锈蚀货架林立,地面积水反射微弱灯光,墙角堆放腐朽木箱和渔网,空气中弥漫潮湿霉味"
- ✗ 差例子:"仓库"
3. **动作(action)字段**必须包含≥25字的详细动作描述包括肢体细节和表情
- ✓ 好例子:"陈峥弯腰双手握住撬棍用力撬动保险箱门,手臂青筋暴起,眉头紧锁,汗水从额头滑落脸颊,呼吸急促"
- ✗ 差例子:"陈峥打开保险箱"
4. **结果(result)字段**必须包含≥25字的详细视觉结果描述
- ✓ 好例子:"保险箱门突然弹开发出刺耳金属声,扬起灰尘在手电筒光束中飘散,箱内空无一物只有几张发黄的旧报纸,陈峥表情从期待转为震惊和失望,瞳孔放大"
- ✗ 差例子:"门打开了"
5. **氛围(atmosphere)字段**必须包含≥20字的环境氛围描述包括光线、色调、声音
- ✓ 好例子:"昏暗冷色调·青灰色为主,只有手电筒光束在黑暗中晃动,远处传来海浪拍打码头的沉闷声,整体氛围压抑沉重"
- ✗ 差例子:"昏暗"
**描述原则**
- 所有描述性字段要像为盲人讲述画面一样详细
- 包含感官细节:视觉、听觉、触觉、嗅觉
- 描述光线、色彩、质感、动态
- 为视频生成AI提供足够的画面构建信息
- 避免抽象词汇,使用具象的视觉化描述`, characterList, sceneList, scriptContent)
// 调用AI服务生成
text, err := s.aiService.GenerateText(prompt, "")
if err != nil {
s.log.Errorw("Failed to generate storyboard", "error", err)
return nil, fmt.Errorf("生成分镜头失败: %w", err)
}
// 解析JSON结果
var result GenerateStoryboardResult
if err := utils.SafeParseAIJSON(text, &result); err != nil {
s.log.Errorw("Failed to parse storyboard JSON", "error", err, "response", text[:min(500, len(text))])
return nil, fmt.Errorf("解析分镜头结果失败: %w", err)
}
result.Total = len(result.Storyboards)
// 计算总时长(所有分镜时长之和)
totalDuration := 0
for _, sb := range result.Storyboards {
totalDuration += sb.Duration
}
s.log.Infow("Storyboard generated",
"episode_id", episodeID,
"count", result.Total,
"total_duration_seconds", totalDuration)
// 保存分镜头到数据库
if err := s.saveStoryboards(episodeID, result.Storyboards); err != nil {
s.log.Errorw("Failed to save storyboards", "error", err)
return nil, fmt.Errorf("保存分镜头失败: %w", err)
}
// 更新剧集时长(秒转分钟,向上取整)
durationMinutes := (totalDuration + 59) / 60
if err := s.db.Model(&models.Episode{}).Where("id = ?", episodeID).Update("duration", durationMinutes).Error; err != nil {
s.log.Errorw("Failed to update episode duration", "error", err)
// 不中断流程,只记录错误
} else {
s.log.Infow("Episode duration updated",
"episode_id", episodeID,
"duration_seconds", totalDuration,
"duration_minutes", durationMinutes)
}
return &result, nil
}
// generateImagePrompt 生成专门用于图片生成的提示词(首帧静态画面)
func (s *StoryboardService) generateImagePrompt(sb Storyboard) string {
var parts []string
// 1. 完整的场景背景描述
if sb.Location != "" {
locationDesc := sb.Location
if sb.Time != "" {
locationDesc += ", " + sb.Time
}
parts = append(parts, locationDesc)
}
// 2. 角色初始静态姿态(去除动作过程,只保留起始状态)
if sb.Action != "" {
initialPose := extractInitialPose(sb.Action)
if initialPose != "" {
parts = append(parts, initialPose)
}
}
// 3. 情绪氛围
if sb.Emotion != "" {
parts = append(parts, sb.Emotion)
}
// 4. 动漫风格
parts = append(parts, "anime style, first frame")
if len(parts) > 0 {
return strings.Join(parts, ", ")
}
return "anime scene"
}
// extractInitialPose 提取初始静态姿态(去除动作过程)
func extractInitialPose(action string) string {
// 去除动作过程关键词,保留初始状态描述
processWords := []string{
"然后", "接着", "接下来", "随后", "紧接着",
"向下", "向上", "向前", "向后", "向左", "向右",
"开始", "继续", "逐渐", "慢慢", "快速", "突然", "猛然",
}
result := action
for _, word := range processWords {
if idx := strings.Index(result, word); idx > 0 {
// 在动作过程词之前截断
result = result[:idx]
break
}
}
// 清理末尾标点
result = strings.TrimRight(result, ",。,. ")
return strings.TrimSpace(result)
}
// extractSimpleLocation 提取简化的场景地点(去除详细描述)
func extractSimpleLocation(location string) string {
// 在"·"符号处截断,只保留主场景名称
if idx := strings.Index(location, "·"); idx > 0 {
return strings.TrimSpace(location[:idx])
}
// 如果有逗号,只保留第一部分
if idx := strings.Index(location, ""); idx > 0 {
return strings.TrimSpace(location[:idx])
}
if idx := strings.Index(location, ","); idx > 0 {
return strings.TrimSpace(location[:idx])
}
// 限制长度不超过15个字符
maxLen := 15
if len(location) > maxLen {
return strings.TrimSpace(location[:maxLen])
}
return strings.TrimSpace(location)
}
// extractSimplePose 提取简单的核心姿态关键词不超过10个字
func extractSimplePose(action string) string {
// 只提取前面最多10个字符作为核心姿态
runes := []rune(action)
maxLen := 10
if len(runes) > maxLen {
// 在标点符号处截断
truncated := runes[:maxLen]
for i := maxLen - 1; i >= 0; i-- {
if truncated[i] == '' || truncated[i] == '。' || truncated[i] == ',' || truncated[i] == '.' {
truncated = runes[:i]
break
}
}
return strings.TrimSpace(string(truncated))
}
return strings.TrimSpace(action)
}
// extractFirstFramePose 从动作描述中提取首帧静态姿态
func extractFirstFramePose(action string) string {
// 去除表示动作过程的关键词,保留初始状态
processWords := []string{
"然后", "接着", "向下", "向前", "走向", "冲向", "转身",
"开始", "继续", "逐渐", "慢慢", "快速", "突然",
}
pose := action
for _, word := range processWords {
// 简单处理:在这些词之前截断
if idx := strings.Index(pose, word); idx > 0 {
pose = pose[:idx]
break
}
}
// 清理末尾标点
pose = strings.TrimRight(pose, ",。,.")
return strings.TrimSpace(pose)
}
// extractCompositionType 从镜头类型中提取构图类型(去除运镜)
func extractCompositionType(shotType string) string {
// 去除运镜相关描述
cameraMovements := []string{
"晃动", "摇晃", "推进", "拉远", "跟随", "环绕",
"运镜", "摄影", "移动", "旋转",
}
comp := shotType
for _, movement := range cameraMovements {
comp = strings.ReplaceAll(comp, movement, "")
}
// 清理多余的标点和空格
comp = strings.ReplaceAll(comp, "··", "·")
comp = strings.ReplaceAll(comp, "·", " ")
comp = strings.TrimSpace(comp)
return comp
}
// generateVideoPrompt 生成专门用于视频生成的提示词(包含运镜和动态元素)
func (s *StoryboardService) generateVideoPrompt(sb Storyboard) string {
var parts []string
// 1. 人物动作
if sb.Action != "" {
parts = append(parts, fmt.Sprintf("Action: %s", sb.Action))
}
// 2. 对话
if sb.Dialogue != "" {
parts = append(parts, fmt.Sprintf("Dialogue: %s", sb.Dialogue))
}
// 3. 镜头运动(视频特有)
if sb.Movement != "" {
parts = append(parts, fmt.Sprintf("Camera movement: %s", sb.Movement))
}
// 4. 镜头类型和角度
if sb.ShotType != "" {
parts = append(parts, fmt.Sprintf("Shot type: %s", sb.ShotType))
}
if sb.Angle != "" {
parts = append(parts, fmt.Sprintf("Camera angle: %s", sb.Angle))
}
// 5. 场景环境
if sb.Location != "" {
locationDesc := sb.Location
if sb.Time != "" {
locationDesc += ", " + sb.Time
}
parts = append(parts, fmt.Sprintf("Scene: %s", locationDesc))
}
// 6. 环境氛围
if sb.Atmosphere != "" {
parts = append(parts, fmt.Sprintf("Atmosphere: %s", sb.Atmosphere))
}
// 7. 情绪和结果
if sb.Emotion != "" {
parts = append(parts, fmt.Sprintf("Mood: %s", sb.Emotion))
}
if sb.Result != "" {
parts = append(parts, fmt.Sprintf("Result: %s", sb.Result))
}
// 8. 音频元素
if sb.BgmPrompt != "" {
parts = append(parts, fmt.Sprintf("BGM: %s", sb.BgmPrompt))
}
if sb.SoundEffect != "" {
parts = append(parts, fmt.Sprintf("Sound effects: %s", sb.SoundEffect))
}
// 9. 视频风格要求
parts = append(parts, "Style: cinematic anime style, smooth camera motion, natural character movement")
if len(parts) > 0 {
return strings.Join(parts, ". ")
}
return "Anime style video scene"
}
func (s *StoryboardService) saveStoryboards(episodeID string, storyboards []Storyboard) error {
// 开启事务
return s.db.Transaction(func(tx *gorm.DB) error {
// 获取该剧集所有的分镜ID
var storyboardIDs []uint
if err := tx.Model(&models.Storyboard{}).
Where("episode_id = ?", episodeID).
Pluck("id", &storyboardIDs).Error; err != nil {
return err
}
// 如果有分镜先清理关联的image_generations的storyboard_id
if len(storyboardIDs) > 0 {
if err := tx.Model(&models.ImageGeneration{}).
Where("storyboard_id IN ?", storyboardIDs).
Update("storyboard_id", nil).Error; err != nil {
return err
}
}
// 删除该剧集已有的分镜头
if err := tx.Where("episode_id = ?", episodeID).Delete(&models.Storyboard{}).Error; err != nil {
return err
}
// 注意:不删除背景,因为背景是在分镜拆解前就提取好的
// AI会直接返回scene_id不需要在这里做字符串匹配
// 保存新的分镜头
for _, sb := range storyboards {
// 构建描述信息,包含对话
description := fmt.Sprintf("【镜头类型】%s\n【运镜】%s\n【动作】%s\n【对话】%s\n【结果】%s\n【情绪】%s",
sb.ShotType, sb.Movement, sb.Action, sb.Dialogue, sb.Result, sb.Emotion)
// 生成两种专用提示词
imagePrompt := s.generateImagePrompt(sb) // 专用于图片生成
videoPrompt := s.generateVideoPrompt(sb) // 专用于视频生成
// 处理 dialogue 字段
var dialoguePtr *string
if sb.Dialogue != "" {
dialoguePtr = &sb.Dialogue
}
// 使用AI直接返回的SceneID
if sb.SceneID != nil {
s.log.Infow("Background ID from AI",
"shot_number", sb.ShotNumber,
"scene_id", *sb.SceneID)
}
epID, _ := strconv.ParseUint(episodeID, 10, 32)
// 处理 title 字段
var titlePtr *string
if sb.Title != "" {
titlePtr = &sb.Title
}
// 处理shot_type、angle、movement字段
var shotTypePtr, anglePtr, movementPtr *string
if sb.ShotType != "" {
shotTypePtr = &sb.ShotType
}
if sb.Angle != "" {
anglePtr = &sb.Angle
}
if sb.Movement != "" {
movementPtr = &sb.Movement
}
// 处理bgm_prompt、sound_effect字段
var bgmPromptPtr, soundEffectPtr *string
if sb.BgmPrompt != "" {
bgmPromptPtr = &sb.BgmPrompt
}
if sb.SoundEffect != "" {
soundEffectPtr = &sb.SoundEffect
}
// 处理result、atmosphere字段
var resultPtr, atmospherePtr *string
if sb.Result != "" {
resultPtr = &sb.Result
}
if sb.Atmosphere != "" {
atmospherePtr = &sb.Atmosphere
}
scene := models.Storyboard{
EpisodeID: uint(epID),
SceneID: sb.SceneID,
StoryboardNumber: sb.ShotNumber,
Title: titlePtr,
Location: &sb.Location,
Time: &sb.Time,
ShotType: shotTypePtr,
Angle: anglePtr,
Movement: movementPtr,
Description: &description,
Action: &sb.Action,
Result: resultPtr,
Atmosphere: atmospherePtr,
Dialogue: dialoguePtr,
ImagePrompt: &imagePrompt,
VideoPrompt: &videoPrompt,
BgmPrompt: bgmPromptPtr,
SoundEffect: soundEffectPtr,
Duration: sb.Duration,
}
if err := tx.Create(&scene).Error; err != nil {
s.log.Errorw("Failed to create scene", "error", err, "shot_number", sb.ShotNumber)
return err
}
// 关联角色
if len(sb.Characters) > 0 {
var characters []models.Character
if err := tx.Where("id IN ?", sb.Characters).Find(&characters).Error; err != nil {
s.log.Warnw("Failed to load characters for association", "error", err, "character_ids", sb.Characters)
} else if len(characters) > 0 {
if err := tx.Model(&scene).Association("Characters").Append(characters); err != nil {
s.log.Warnw("Failed to associate characters", "error", err, "shot_number", sb.ShotNumber)
} else {
s.log.Infow("Characters associated successfully",
"shot_number", sb.ShotNumber,
"character_ids", sb.Characters,
"count", len(characters))
}
}
}
}
s.log.Infow("Storyboards saved successfully", "episode_id", episodeID, "count", len(storyboards))
return nil
})
}
// UpdateStoryboardCharacters 更新分镜的角色关联
func (s *StoryboardService) UpdateStoryboardCharacters(storyboardID string, characterIDs []uint) error {
// 查找分镜
var storyboard models.Storyboard
if err := s.db.First(&storyboard, storyboardID).Error; err != nil {
return fmt.Errorf("storyboard not found: %w", err)
}
// 清除现有的角色关联
if err := s.db.Model(&storyboard).Association("Characters").Clear(); err != nil {
return fmt.Errorf("failed to clear characters: %w", err)
}
// 如果有新的角色ID加载并关联
if len(characterIDs) > 0 {
var characters []models.Character
if err := s.db.Where("id IN ?", characterIDs).Find(&characters).Error; err != nil {
return fmt.Errorf("failed to find characters: %w", err)
}
if err := s.db.Model(&storyboard).Association("Characters").Append(characters); err != nil {
return fmt.Errorf("failed to associate characters: %w", err)
}
}
s.log.Infow("Storyboard characters updated", "storyboard_id", storyboardID, "character_count", len(characterIDs))
return nil
}
func min(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,138 @@
package services
import (
"fmt"
"github.com/drama-generator/backend/domain/models"
)
// UpdateStoryboard 更新分镜的所有字段,并重新生成提示词
func (s *StoryboardService) UpdateStoryboard(storyboardID string, updates map[string]interface{}) error {
// 查找分镜
var storyboard models.Storyboard
if err := s.db.First(&storyboard, storyboardID).Error; err != nil {
return fmt.Errorf("storyboard not found: %w", err)
}
// 构建用于重新生成提示词的Storyboard结构
sb := Storyboard{
ShotNumber: storyboard.StoryboardNumber,
}
// 从updates中提取字段并更新
updateData := make(map[string]interface{})
if val, ok := updates["title"].(string); ok && val != "" {
updateData["title"] = val
sb.Title = val
}
if val, ok := updates["shot_type"].(string); ok && val != "" {
updateData["shot_type"] = val
sb.ShotType = val
}
if val, ok := updates["angle"].(string); ok && val != "" {
updateData["angle"] = val
sb.Angle = val
}
if val, ok := updates["movement"].(string); ok && val != "" {
updateData["movement"] = val
sb.Movement = val
}
if val, ok := updates["location"].(string); ok && val != "" {
updateData["location"] = val
sb.Location = val
}
if val, ok := updates["time"].(string); ok && val != "" {
updateData["time"] = val
sb.Time = val
}
if val, ok := updates["action"].(string); ok && val != "" {
updateData["action"] = val
sb.Action = val
}
if val, ok := updates["dialogue"].(string); ok && val != "" {
updateData["dialogue"] = val
sb.Dialogue = val
}
if val, ok := updates["result"].(string); ok && val != "" {
updateData["result"] = val
sb.Result = val
}
if val, ok := updates["atmosphere"].(string); ok && val != "" {
updateData["atmosphere"] = val
sb.Atmosphere = val
}
if val, ok := updates["description"].(string); ok && val != "" {
updateData["description"] = val
}
if val, ok := updates["bgm_prompt"].(string); ok && val != "" {
updateData["bgm_prompt"] = val
sb.BgmPrompt = val
}
if val, ok := updates["sound_effect"].(string); ok && val != "" {
updateData["sound_effect"] = val
sb.SoundEffect = val
}
if val, ok := updates["duration"].(float64); ok {
updateData["duration"] = int(val)
sb.Duration = int(val)
}
// 使用当前数据库值填充缺失字段(用于生成提示词)
if sb.Title == "" && storyboard.Title != nil {
sb.Title = *storyboard.Title
}
if sb.ShotType == "" && storyboard.ShotType != nil {
sb.ShotType = *storyboard.ShotType
}
if sb.Angle == "" && storyboard.Angle != nil {
sb.Angle = *storyboard.Angle
}
if sb.Movement == "" && storyboard.Movement != nil {
sb.Movement = *storyboard.Movement
}
if sb.Location == "" && storyboard.Location != nil {
sb.Location = *storyboard.Location
}
if sb.Time == "" && storyboard.Time != nil {
sb.Time = *storyboard.Time
}
if sb.Action == "" && storyboard.Action != nil {
sb.Action = *storyboard.Action
}
if sb.Dialogue == "" && storyboard.Dialogue != nil {
sb.Dialogue = *storyboard.Dialogue
}
if sb.Result == "" && storyboard.Result != nil {
sb.Result = *storyboard.Result
}
if sb.Atmosphere == "" && storyboard.Atmosphere != nil {
sb.Atmosphere = *storyboard.Atmosphere
}
if sb.BgmPrompt == "" && storyboard.BgmPrompt != nil {
sb.BgmPrompt = *storyboard.BgmPrompt
}
if sb.SoundEffect == "" && storyboard.SoundEffect != nil {
sb.SoundEffect = *storyboard.SoundEffect
}
if sb.Duration == 0 {
sb.Duration = storyboard.Duration
}
// 只重新生成video_prompt
// image_prompt不自动更新因为可能对应多张已生成的帧图片
videoPrompt := s.generateVideoPrompt(sb)
updateData["video_prompt"] = videoPrompt
// 更新数据库
if err := s.db.Model(&storyboard).Updates(updateData).Error; err != nil {
return fmt.Errorf("failed to update storyboard: %w", err)
}
s.log.Infow("Storyboard updated successfully",
"storyboard_id", storyboardID,
"fields_updated", len(updateData))
return nil
}

View File

@@ -0,0 +1,113 @@
package services
import (
"encoding/json"
"fmt"
"time"
"github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/logger"
"github.com/google/uuid"
"gorm.io/gorm"
)
type TaskService struct {
db *gorm.DB
log *logger.Logger
}
func NewTaskService(db *gorm.DB, log *logger.Logger) *TaskService {
return &TaskService{
db: db,
log: log,
}
}
// CreateTask 创建新任务
func (s *TaskService) CreateTask(taskType, resourceID string) (*models.AsyncTask, error) {
task := &models.AsyncTask{
ID: uuid.New().String(),
Type: taskType,
Status: "pending",
Progress: 0,
ResourceID: resourceID,
}
if err := s.db.Create(task).Error; err != nil {
return nil, fmt.Errorf("failed to create task: %w", err)
}
return task, nil
}
// UpdateTaskStatus 更新任务状态
func (s *TaskService) UpdateTaskStatus(taskID, status string, progress int, message string) error {
updates := map[string]interface{}{
"status": status,
"progress": progress,
"message": message,
"updated_at": time.Now(),
}
if status == "completed" || status == "failed" {
now := time.Now()
updates["completed_at"] = &now
}
return s.db.Model(&models.AsyncTask{}).
Where("id = ?", taskID).
Updates(updates).Error
}
// UpdateTaskError 更新任务错误
func (s *TaskService) UpdateTaskError(taskID string, err error) error {
now := time.Now()
return s.db.Model(&models.AsyncTask{}).
Where("id = ?", taskID).
Updates(map[string]interface{}{
"status": "failed",
"error": err.Error(),
"progress": 0,
"completed_at": &now,
"updated_at": time.Now(),
}).Error
}
// UpdateTaskResult 更新任务结果
func (s *TaskService) UpdateTaskResult(taskID string, result interface{}) error {
resultJSON, err := json.Marshal(result)
if err != nil {
return fmt.Errorf("failed to marshal result: %w", err)
}
now := time.Now()
return s.db.Model(&models.AsyncTask{}).
Where("id = ?", taskID).
Updates(map[string]interface{}{
"status": "completed",
"progress": 100,
"result": string(resultJSON),
"completed_at": &now,
"updated_at": time.Now(),
}).Error
}
// GetTask 获取任务信息
func (s *TaskService) GetTask(taskID string) (*models.AsyncTask, error) {
var task models.AsyncTask
if err := s.db.Where("id = ?", taskID).First(&task).Error; err != nil {
return nil, err
}
return &task, nil
}
// GetTasksByResource 获取资源相关的所有任务
func (s *TaskService) GetTasksByResource(resourceID string) ([]*models.AsyncTask, error) {
var tasks []*models.AsyncTask
if err := s.db.Where("resource_id = ?", resourceID).
Order("created_at DESC").
Find(&tasks).Error; err != nil {
return nil, err
}
return tasks, nil
}

View File

@@ -0,0 +1,109 @@
package services
import (
"fmt"
"io"
"os"
"path/filepath"
"time"
"github.com/drama-generator/backend/pkg/config"
"github.com/drama-generator/backend/pkg/logger"
"github.com/google/uuid"
)
type UploadService struct {
storagePath string
baseURL string
log *logger.Logger
}
func NewUploadService(cfg *config.Config, log *logger.Logger) (*UploadService, error) {
// 确保存储目录存在
if err := os.MkdirAll(cfg.Storage.LocalPath, 0755); err != nil {
return nil, fmt.Errorf("failed to create storage directory: %w", err)
}
return &UploadService{
storagePath: cfg.Storage.LocalPath,
baseURL: cfg.Storage.BaseURL,
log: log,
}, nil
}
// UploadFile 上传文件到本地存储
func (s *UploadService) UploadFile(file io.Reader, fileName, contentType string, category string) (string, error) {
// 创建分类目录
categoryPath := filepath.Join(s.storagePath, category)
if err := os.MkdirAll(categoryPath, 0755); err != nil {
return "", fmt.Errorf("failed to create category directory: %w", err)
}
// 生成唯一文件名
ext := filepath.Ext(fileName)
uniqueID := uuid.New().String()
timestamp := time.Now().Format("20060102_150405")
newFileName := fmt.Sprintf("%s_%s%s", timestamp, uniqueID, ext)
filePath := filepath.Join(categoryPath, newFileName)
// 创建文件
dst, err := os.Create(filePath)
if err != nil {
s.log.Errorw("Failed to create file", "error", err, "path", filePath)
return "", fmt.Errorf("创建文件失败: %w", err)
}
defer dst.Close()
// 写入文件
if _, err := io.Copy(dst, file); err != nil {
s.log.Errorw("Failed to write file", "error", err, "path", filePath)
return "", fmt.Errorf("写入文件失败: %w", err)
}
// 构建访问URL
fileURL := fmt.Sprintf("%s/%s/%s", s.baseURL, category, newFileName)
s.log.Infow("File uploaded successfully", "path", filePath, "url", fileURL)
return fileURL, nil
}
// UploadCharacterImage 上传角色图片
func (s *UploadService) UploadCharacterImage(file io.Reader, fileName, contentType string) (string, error) {
return s.UploadFile(file, fileName, contentType, "characters")
}
// DeleteFile 删除本地文件
func (s *UploadService) DeleteFile(fileURL string) error {
// 从URL中提取相对路径
// URL格式: http://localhost:8080/static/characters/20060102_150405_uuid.jpg
relPath := s.extractRelativePathFromURL(fileURL)
if relPath == "" {
return fmt.Errorf("invalid file URL")
}
filePath := filepath.Join(s.storagePath, relPath)
err := os.Remove(filePath)
if err != nil {
s.log.Errorw("Failed to delete file", "error", err, "path", filePath)
return fmt.Errorf("删除文件失败: %w", err)
}
s.log.Infow("File deleted successfully", "path", filePath)
return nil
}
// extractRelativePathFromURL 从URL中提取相对路径
func (s *UploadService) extractRelativePathFromURL(fileURL string) string {
// 从baseURL后面提取路径
// 例如: http://localhost:8080/static/characters/xxx.jpg -> characters/xxx.jpg
if len(fileURL) <= len(s.baseURL) {
return ""
}
return fileURL[len(s.baseURL)+1:] // +1 for the '/'
}
// GetPresignedURL 本地存储不需要预签名URL直接返回原URL
func (s *UploadService) GetPresignedURL(objectName string, expiry time.Duration) (string, error) {
// 本地存储通过静态文件服务直接访问,不需要预签名
return fmt.Sprintf("%s/%s", s.baseURL, objectName), nil
}

View File

@@ -0,0 +1,566 @@
package services
import (
"encoding/json"
"fmt"
"strconv"
"time"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/infrastructure/storage"
"github.com/drama-generator/backend/pkg/logger"
"github.com/drama-generator/backend/pkg/video"
"gorm.io/gorm"
)
type VideoGenerationService struct {
db *gorm.DB
transferService *ResourceTransferService
log *logger.Logger
localStorage *storage.LocalStorage
aiService *AIService
}
func NewVideoGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, aiService *AIService, log *logger.Logger) *VideoGenerationService {
service := &VideoGenerationService{
db: db,
localStorage: localStorage,
transferService: transferService,
aiService: aiService,
log: log,
}
go service.RecoverPendingTasks()
return service
}
type GenerateVideoRequest struct {
StoryboardID *uint `json:"storyboard_id"`
DramaID string `json:"drama_id" binding:"required"`
ImageGenID *uint `json:"image_gen_id"`
// 参考图模式single, first_last, multiple, none
ReferenceMode string `json:"reference_mode"`
// 单图模式
ImageURL string `json:"image_url"`
// 首尾帧模式
FirstFrameURL *string `json:"first_frame_url"`
LastFrameURL *string `json:"last_frame_url"`
// 多图模式
ReferenceImageURLs []string `json:"reference_image_urls"`
Prompt string `json:"prompt" binding:"required,min=5,max=2000"`
Provider string `json:"provider"`
Model string `json:"model"`
Duration *int `json:"duration"`
FPS *int `json:"fps"`
AspectRatio *string `json:"aspect_ratio"`
Style *string `json:"style"`
MotionLevel *int `json:"motion_level"`
CameraMotion *string `json:"camera_motion"`
Seed *int64 `json:"seed"`
}
func (s *VideoGenerationService) GenerateVideo(request *GenerateVideoRequest) (*models.VideoGeneration, error) {
if request.StoryboardID != nil {
var storyboard models.Storyboard
if err := s.db.Preload("Episode").Where("id = ?", *request.StoryboardID).First(&storyboard).Error; err != nil {
return nil, fmt.Errorf("storyboard not found")
}
if fmt.Sprintf("%d", storyboard.Episode.DramaID) != request.DramaID {
return nil, fmt.Errorf("storyboard does not belong to drama")
}
}
if request.ImageGenID != nil {
var imageGen models.ImageGeneration
if err := s.db.Where("id = ?", *request.ImageGenID).First(&imageGen).Error; err != nil {
return nil, fmt.Errorf("image generation not found")
}
}
provider := request.Provider
if provider == "" {
provider = "doubao"
}
dramaID, _ := strconv.ParseUint(request.DramaID, 10, 32)
videoGen := &models.VideoGeneration{
StoryboardID: request.StoryboardID,
DramaID: uint(dramaID),
ImageGenID: request.ImageGenID,
Provider: provider,
Prompt: request.Prompt,
Model: request.Model,
Duration: request.Duration,
FPS: request.FPS,
AspectRatio: request.AspectRatio,
Style: request.Style,
MotionLevel: request.MotionLevel,
CameraMotion: request.CameraMotion,
Seed: request.Seed,
Status: models.VideoStatusPending,
}
// 根据参考图模式处理不同的参数
if request.ReferenceMode != "" {
videoGen.ReferenceMode = &request.ReferenceMode
}
switch request.ReferenceMode {
case "single":
// 单图模式
if request.ImageURL != "" {
videoGen.ImageURL = &request.ImageURL
}
case "first_last":
// 首尾帧模式
if request.FirstFrameURL != nil {
videoGen.FirstFrameURL = request.FirstFrameURL
}
if request.LastFrameURL != nil {
videoGen.LastFrameURL = request.LastFrameURL
}
case "multiple":
// 多图模式
if len(request.ReferenceImageURLs) > 0 {
referenceImagesJSON, err := json.Marshal(request.ReferenceImageURLs)
if err == nil {
referenceImagesStr := string(referenceImagesJSON)
videoGen.ReferenceImageURLs = &referenceImagesStr
}
}
case "none":
// 无参考图,纯文本生成
default:
// 向后兼容:如果没有指定模式,根据提供的参数自动判断
if request.ImageURL != "" {
videoGen.ImageURL = &request.ImageURL
mode := "single"
videoGen.ReferenceMode = &mode
} else if request.FirstFrameURL != nil || request.LastFrameURL != nil {
videoGen.FirstFrameURL = request.FirstFrameURL
videoGen.LastFrameURL = request.LastFrameURL
mode := "first_last"
videoGen.ReferenceMode = &mode
} else if len(request.ReferenceImageURLs) > 0 {
referenceImagesJSON, err := json.Marshal(request.ReferenceImageURLs)
if err == nil {
referenceImagesStr := string(referenceImagesJSON)
videoGen.ReferenceImageURLs = &referenceImagesStr
mode := "multiple"
videoGen.ReferenceMode = &mode
}
}
}
if err := s.db.Create(videoGen).Error; err != nil {
return nil, fmt.Errorf("failed to create record: %w", err)
}
go s.ProcessVideoGeneration(videoGen.ID)
return videoGen, nil
}
func (s *VideoGenerationService) ProcessVideoGeneration(videoGenID uint) {
var videoGen models.VideoGeneration
if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
s.log.Errorw("Failed to load video generation", "error", err, "id", videoGenID)
return
}
s.db.Model(&videoGen).Update("status", models.VideoStatusProcessing)
client, err := s.getVideoClient(videoGen.Provider, videoGen.Model)
if err != nil {
s.log.Errorw("Failed to get video client", "error", err, "provider", videoGen.Provider, "model", videoGen.Model)
s.updateVideoGenError(videoGenID, err.Error())
return
}
s.log.Infow("Starting video generation", "id", videoGenID, "prompt", videoGen.Prompt, "provider", videoGen.Provider)
var opts []video.VideoOption
if videoGen.Model != "" {
opts = append(opts, video.WithModel(videoGen.Model))
}
if videoGen.Duration != nil {
opts = append(opts, video.WithDuration(*videoGen.Duration))
}
if videoGen.FPS != nil {
opts = append(opts, video.WithFPS(*videoGen.FPS))
}
if videoGen.AspectRatio != nil {
opts = append(opts, video.WithAspectRatio(*videoGen.AspectRatio))
}
if videoGen.Style != nil {
opts = append(opts, video.WithStyle(*videoGen.Style))
}
if videoGen.MotionLevel != nil {
opts = append(opts, video.WithMotionLevel(*videoGen.MotionLevel))
}
if videoGen.CameraMotion != nil {
opts = append(opts, video.WithCameraMotion(*videoGen.CameraMotion))
}
if videoGen.Seed != nil {
opts = append(opts, video.WithSeed(*videoGen.Seed))
}
// 根据参考图模式添加相应的选项
if videoGen.ReferenceMode != nil {
switch *videoGen.ReferenceMode {
case "first_last":
// 首尾帧模式
if videoGen.FirstFrameURL != nil {
opts = append(opts, video.WithFirstFrame(*videoGen.FirstFrameURL))
}
if videoGen.LastFrameURL != nil {
opts = append(opts, video.WithLastFrame(*videoGen.LastFrameURL))
}
case "multiple":
// 多图模式
if videoGen.ReferenceImageURLs != nil {
var imageURLs []string
if err := json.Unmarshal([]byte(*videoGen.ReferenceImageURLs), &imageURLs); err == nil {
opts = append(opts, video.WithReferenceImages(imageURLs))
}
}
}
}
// 构造imageURL参数单图模式使用其他模式传空字符串
imageURL := ""
if videoGen.ImageURL != nil {
imageURL = *videoGen.ImageURL
}
result, err := client.GenerateVideo(imageURL, videoGen.Prompt, opts...)
if err != nil {
s.log.Errorw("Video generation API call failed", "error", err, "id", videoGenID)
s.updateVideoGenError(videoGenID, err.Error())
return
}
if result.TaskID != "" {
s.db.Model(&videoGen).Updates(map[string]interface{}{
"task_id": result.TaskID,
"status": models.VideoStatusProcessing,
})
go s.pollTaskStatus(videoGenID, result.TaskID, videoGen.Provider, videoGen.Model)
return
}
if result.VideoURL != "" {
s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
return
}
s.updateVideoGenError(videoGenID, "no task ID or video URL returned")
}
func (s *VideoGenerationService) pollTaskStatus(videoGenID uint, taskID string, provider string, model string) {
client, err := s.getVideoClient(provider, model)
if err != nil {
s.log.Errorw("Failed to get video client for polling", "error", err)
s.updateVideoGenError(videoGenID, "failed to get video client")
return
}
maxAttempts := 300
interval := 10 * time.Second
for attempt := 0; attempt < maxAttempts; attempt++ {
time.Sleep(interval)
var videoGen models.VideoGeneration
if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
s.log.Errorw("Failed to load video generation", "error", err, "id", videoGenID)
return
}
if videoGen.Status != models.VideoStatusProcessing {
s.log.Infow("Video generation status changed, stopping poll", "id", videoGenID, "status", videoGen.Status)
return
}
result, err := client.GetTaskStatus(taskID)
if err != nil {
s.log.Errorw("Failed to get task status", "error", err, "task_id", taskID)
continue
}
if result.Completed {
if result.VideoURL != "" {
s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
return
}
s.updateVideoGenError(videoGenID, "task completed but no video URL")
return
}
if result.Error != "" {
s.updateVideoGenError(videoGenID, result.Error)
return
}
s.log.Infow("Video generation in progress", "id", videoGenID, "attempt", attempt+1)
}
s.updateVideoGenError(videoGenID, "polling timeout")
}
func (s *VideoGenerationService) completeVideoGeneration(videoGenID uint, videoURL string, duration *int, width *int, height *int, firstFrameURL *string) {
// 下载视频到本地存储(仅用于缓存,不更新数据库)
if s.localStorage != nil && videoURL != "" {
_, err := s.localStorage.DownloadFromURL(videoURL, "videos")
if err != nil {
s.log.Warnw("Failed to download video to local storage",
"error", err,
"id", videoGenID,
"original_url", videoURL)
} else {
s.log.Infow("Video downloaded to local storage for caching",
"id", videoGenID,
"original_url", videoURL)
}
}
// 下载首帧图片到本地存储(仅用于缓存,不更新数据库)
if firstFrameURL != nil && *firstFrameURL != "" && s.localStorage != nil {
_, err := s.localStorage.DownloadFromURL(*firstFrameURL, "video_frames")
if err != nil {
s.log.Warnw("Failed to download first frame to local storage",
"error", err,
"id", videoGenID,
"original_url", *firstFrameURL)
} else {
s.log.Infow("First frame downloaded to local storage for caching",
"id", videoGenID,
"original_url", *firstFrameURL)
}
}
// 数据库中保持使用原始URL
updates := map[string]interface{}{
"status": models.VideoStatusCompleted,
"video_url": videoURL,
}
if duration != nil {
updates["duration"] = *duration
}
if width != nil {
updates["width"] = *width
}
if height != nil {
updates["height"] = *height
}
if firstFrameURL != nil {
updates["first_frame_url"] = *firstFrameURL
}
if err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(updates).Error; err != nil {
s.log.Errorw("Failed to update video generation", "error", err, "id", videoGenID)
return
}
var videoGen models.VideoGeneration
if err := s.db.First(&videoGen, videoGenID).Error; err == nil {
if videoGen.StoryboardID != nil {
if err := s.db.Model(&models.Storyboard{}).Where("id = ?", *videoGen.StoryboardID).Update("video_url", videoURL).Error; err != nil {
s.log.Warnw("Failed to update storyboard video_url", "storyboard_id", *videoGen.StoryboardID, "error", err)
}
}
}
s.log.Infow("Video generation completed", "id", videoGenID, "url", videoURL)
}
func (s *VideoGenerationService) updateVideoGenError(videoGenID uint, errorMsg string) {
if err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(map[string]interface{}{
"status": models.VideoStatusFailed,
"error_msg": errorMsg,
}).Error; err != nil {
s.log.Errorw("Failed to update video generation error", "error", err, "id", videoGenID)
}
}
func (s *VideoGenerationService) getVideoClient(provider string, modelName string) (video.VideoClient, error) {
// 根据模型名称获取AI配置
var config *models.AIServiceConfig
var err error
if modelName != "" {
config, err = s.aiService.GetConfigForModel("video", modelName)
if err != nil {
s.log.Warnw("Failed to get config for model, using default", "model", modelName, "error", err)
config, err = s.aiService.GetDefaultConfig("video")
if err != nil {
return nil, fmt.Errorf("no video AI config found: %w", err)
}
}
} else {
config, err = s.aiService.GetDefaultConfig("video")
if err != nil {
return nil, fmt.Errorf("no video AI config found: %w", err)
}
}
// 使用配置中的信息创建客户端
baseURL := config.BaseURL
apiKey := config.APIKey
model := modelName
if model == "" && len(config.Model) > 0 {
model = config.Model[0]
}
// 根据配置中的 provider 创建对应的客户端
var endpoint string
var queryEndpoint string
switch config.Provider {
case "chatfire":
endpoint = "/video/generations"
queryEndpoint = "/video/task/{taskId}"
return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
case "doubao", "volcengine", "volces":
endpoint = "/contents/generations/tasks"
queryEndpoint = "/contents/generations/tasks/{taskId}"
return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
case "openai":
// OpenAI Sora 使用 /v1/videos 端点
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
case "runway":
return video.NewRunwayClient(baseURL, apiKey, model), nil
case "pika":
return video.NewPikaClient(baseURL, apiKey, model), nil
case "minimax":
return video.NewMinimaxClient(baseURL, apiKey, model), nil
default:
return nil, fmt.Errorf("unsupported video provider: %s", provider)
}
}
func (s *VideoGenerationService) RecoverPendingTasks() {
var pendingVideos []models.VideoGeneration
if err := s.db.Where("status = ? AND task_id != ''", models.VideoStatusProcessing).Find(&pendingVideos).Error; err != nil {
s.log.Errorw("Failed to load pending video tasks", "error", err)
return
}
s.log.Infow("Recovering pending video generation tasks", "count", len(pendingVideos))
for _, videoGen := range pendingVideos {
go s.pollTaskStatus(videoGen.ID, *videoGen.TaskID, videoGen.Provider, videoGen.Model)
}
}
func (s *VideoGenerationService) GetVideoGeneration(id uint) (*models.VideoGeneration, error) {
var videoGen models.VideoGeneration
if err := s.db.First(&videoGen, id).Error; err != nil {
return nil, err
}
return &videoGen, nil
}
func (s *VideoGenerationService) ListVideoGenerations(dramaID *uint, storyboardID *uint, status string, limit int, offset int) ([]*models.VideoGeneration, int64, error) {
var videos []*models.VideoGeneration
var total int64
query := s.db.Model(&models.VideoGeneration{})
if dramaID != nil {
query = query.Where("drama_id = ?", *dramaID)
}
if storyboardID != nil {
query = query.Where("storyboard_id = ?", *storyboardID)
}
if status != "" {
query = query.Where("status = ?", status)
}
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Limit(limit).Offset(offset).Find(&videos).Error; err != nil {
return nil, 0, err
}
return videos, total, nil
}
func (s *VideoGenerationService) GenerateVideoFromImage(imageGenID uint) (*models.VideoGeneration, error) {
var imageGen models.ImageGeneration
if err := s.db.First(&imageGen, imageGenID).Error; err != nil {
return nil, fmt.Errorf("image generation not found")
}
if imageGen.Status != models.ImageStatusCompleted || imageGen.ImageURL == nil {
return nil, fmt.Errorf("image is not ready")
}
// 获取关联的Storyboard以获取时长
var duration *int
if imageGen.StoryboardID != nil {
var storyboard models.Storyboard
if err := s.db.Where("id = ?", *imageGen.StoryboardID).First(&storyboard).Error; err == nil {
duration = &storyboard.Duration
s.log.Infow("Using storyboard duration for video generation",
"storyboard_id", *imageGen.StoryboardID,
"duration", storyboard.Duration)
}
}
req := &GenerateVideoRequest{
DramaID: fmt.Sprintf("%d", imageGen.DramaID),
StoryboardID: imageGen.StoryboardID,
ImageGenID: &imageGenID,
ImageURL: *imageGen.ImageURL,
Prompt: imageGen.Prompt,
Provider: "doubao",
Duration: duration,
}
return s.GenerateVideo(req)
}
func (s *VideoGenerationService) BatchGenerateVideosForEpisode(episodeID string) ([]*models.VideoGeneration, error) {
var episode models.Episode
if err := s.db.Preload("Storyboards").Where("id = ?", episodeID).First(&episode).Error; err != nil {
return nil, fmt.Errorf("episode not found")
}
var results []*models.VideoGeneration
for _, storyboard := range episode.Storyboards {
if storyboard.ImagePrompt == nil {
continue
}
var imageGen models.ImageGeneration
if err := s.db.Where("storyboard_id = ? AND status = ?", storyboard.ID, models.ImageStatusCompleted).
Order("created_at DESC").First(&imageGen).Error; err != nil {
s.log.Warnw("No completed image for storyboard", "storyboard_id", storyboard.ID)
continue
}
videoGen, err := s.GenerateVideoFromImage(imageGen.ID)
if err != nil {
s.log.Errorw("Failed to generate video", "storyboard_id", storyboard.ID, "error", err)
continue
}
results = append(results, videoGen)
}
return results, nil
}
func (s *VideoGenerationService) DeleteVideoGeneration(id uint) error {
return s.db.Delete(&models.VideoGeneration{}, id).Error
}

View File

@@ -0,0 +1,557 @@
package services
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
"strconv"
"time"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/infrastructure/external/ffmpeg"
"github.com/drama-generator/backend/pkg/logger"
"github.com/drama-generator/backend/pkg/video"
"gorm.io/gorm"
)
type VideoMergeService struct {
db *gorm.DB
aiService *AIService
transferService *ResourceTransferService
ffmpeg *ffmpeg.FFmpeg
storagePath string
baseURL string
log *logger.Logger
}
func NewVideoMergeService(db *gorm.DB, transferService *ResourceTransferService, storagePath, baseURL string, log *logger.Logger) *VideoMergeService {
return &VideoMergeService{
db: db,
aiService: NewAIService(db, log),
transferService: transferService,
ffmpeg: ffmpeg.NewFFmpeg(log),
storagePath: storagePath,
baseURL: baseURL,
log: log,
}
}
type MergeVideoRequest struct {
EpisodeID string `json:"episode_id" binding:"required"`
DramaID string `json:"drama_id" binding:"required"`
Title string `json:"title"`
Scenes []models.SceneClip `json:"scenes" binding:"required,min=1"`
Provider string `json:"provider"`
Model string `json:"model"`
}
func (s *VideoMergeService) MergeVideos(req *MergeVideoRequest) (*models.VideoMerge, error) {
// 验证episode权限
var episode models.Episode
if err := s.db.Preload("Drama").Where("id = ?", req.EpisodeID).First(&episode).Error; err != nil {
return nil, fmt.Errorf("episode not found")
}
// 验证所有场景都有视频
for i, scene := range req.Scenes {
if scene.VideoURL == "" {
return nil, fmt.Errorf("scene %d has no video", i+1)
}
}
provider := req.Provider
if provider == "" {
provider = "doubao"
}
// 序列化场景列表
scenesJSON, err := json.Marshal(req.Scenes)
if err != nil {
return nil, fmt.Errorf("failed to serialize scenes: %w", err)
}
s.log.Infow("Serialized scenes to JSON",
"scenes_count", len(req.Scenes),
"scenes_json", string(scenesJSON))
epID, _ := strconv.ParseUint(req.EpisodeID, 10, 32)
dramaID, _ := strconv.ParseUint(req.DramaID, 10, 32)
videoMerge := &models.VideoMerge{
EpisodeID: uint(epID),
DramaID: uint(dramaID),
Title: req.Title,
Provider: provider,
Model: &req.Model,
Scenes: scenesJSON,
Status: models.VideoMergeStatusPending,
}
if err := s.db.Create(videoMerge).Error; err != nil {
return nil, fmt.Errorf("failed to create merge record: %w", err)
}
go s.processMergeVideo(videoMerge.ID)
return videoMerge, nil
}
func (s *VideoMergeService) processMergeVideo(mergeID uint) {
var videoMerge models.VideoMerge
if err := s.db.First(&videoMerge, mergeID).Error; err != nil {
s.log.Errorw("Failed to load video merge", "error", err, "id", mergeID)
return
}
s.db.Model(&videoMerge).Update("status", models.VideoMergeStatusProcessing)
client, err := s.getVideoClient(videoMerge.Provider)
if err != nil {
s.updateMergeError(mergeID, err.Error())
return
}
// 解析场景列表
var scenes []models.SceneClip
if err := json.Unmarshal(videoMerge.Scenes, &scenes); err != nil {
s.updateMergeError(mergeID, fmt.Sprintf("failed to parse scenes: %v", err))
return
}
// 调用视频合并API
result, err := s.mergeVideoClips(client, scenes)
if err != nil {
s.updateMergeError(mergeID, err.Error())
return
}
if !result.Completed {
s.db.Model(&videoMerge).Updates(map[string]interface{}{
"status": models.VideoMergeStatusProcessing,
"task_id": result.TaskID,
})
go s.pollMergeStatus(mergeID, client, result.TaskID)
return
}
s.completeMerge(mergeID, result)
}
func (s *VideoMergeService) mergeVideoClips(client video.VideoClient, scenes []models.SceneClip) (*video.VideoResult, error) {
if len(scenes) == 0 {
return nil, fmt.Errorf("no scenes to merge")
}
// 按Order字段排序场景
sort.Slice(scenes, func(i, j int) bool {
return scenes[i].Order < scenes[j].Order
})
s.log.Infow("Merging video clips with FFmpeg", "scene_count", len(scenes))
// 计算总时长
var totalDuration float64
for _, scene := range scenes {
totalDuration += scene.Duration
}
// 准备FFmpeg合成选项
clips := make([]ffmpeg.VideoClip, len(scenes))
for i, scene := range scenes {
clips[i] = ffmpeg.VideoClip{
URL: scene.VideoURL,
Duration: scene.Duration,
StartTime: scene.StartTime,
EndTime: scene.EndTime,
Transition: scene.Transition,
}
s.log.Infow("Clip added to merge queue",
"order", scene.Order,
"index", i,
"duration", scene.Duration,
"start_time", scene.StartTime,
"end_time", scene.EndTime)
}
// 创建视频输出目录
videoDir := filepath.Join(s.storagePath, "videos", "merged")
if err := os.MkdirAll(videoDir, 0755); err != nil {
return nil, fmt.Errorf("failed to create video directory: %w", err)
}
// 生成输出文件名
fileName := fmt.Sprintf("merged_%d.mp4", time.Now().Unix())
outputPath := filepath.Join(videoDir, fileName)
// 使用FFmpeg合成视频
mergedPath, err := s.ffmpeg.MergeVideos(&ffmpeg.MergeOptions{
OutputPath: outputPath,
Clips: clips,
})
if err != nil {
return nil, fmt.Errorf("ffmpeg merge failed: %w", err)
}
s.log.Infow("Video merged successfully", "path", mergedPath)
// 生成访问URL相对路径
relPath := filepath.Join("videos", "merged", fileName)
videoURL := fmt.Sprintf("%s/%s", s.baseURL, relPath)
result := &video.VideoResult{
VideoURL: videoURL, // 返回可访问的URL
Duration: int(totalDuration),
Completed: true,
Status: "completed",
}
return result, nil
}
func (s *VideoMergeService) pollMergeStatus(mergeID uint, client video.VideoClient, taskID string) {
maxAttempts := 240
pollInterval := 5 * time.Second
for i := 0; i < maxAttempts; i++ {
time.Sleep(pollInterval)
result, err := client.GetTaskStatus(taskID)
if err != nil {
s.log.Errorw("Failed to get merge task status", "error", err, "task_id", taskID)
continue
}
if result.Completed {
s.completeMerge(mergeID, result)
return
}
if result.Error != "" {
s.updateMergeError(mergeID, result.Error)
return
}
}
s.updateMergeError(mergeID, "timeout: video merge took too long")
}
func (s *VideoMergeService) completeMerge(mergeID uint, result *video.VideoResult) {
now := time.Now()
// 获取merge记录
var videoMerge models.VideoMerge
if err := s.db.First(&videoMerge, mergeID).Error; err != nil {
s.log.Errorw("Failed to load video merge for completion", "error", err, "id", mergeID)
return
}
finalVideoURL := result.VideoURL
// 使用本地存储不再使用MinIO
s.log.Infow("Video merge completed, using local storage", "merge_id", mergeID, "local_path", result.VideoURL)
updates := map[string]interface{}{
"status": models.VideoMergeStatusCompleted,
"merged_url": finalVideoURL,
"completed_at": now,
}
if result.Duration > 0 {
updates["duration"] = result.Duration
}
s.db.Model(&models.VideoMerge{}).Where("id = ?", mergeID).Updates(updates)
// 更新episode的状态和最终视频URL
if videoMerge.EpisodeID != 0 {
s.db.Model(&models.Episode{}).Where("id = ?", videoMerge.EpisodeID).Updates(map[string]interface{}{
"status": "completed",
"video_url": finalVideoURL,
})
s.log.Infow("Episode finalized", "episode_id", videoMerge.EpisodeID, "video_url", finalVideoURL)
}
s.log.Infow("Video merge completed", "id", mergeID, "url", finalVideoURL)
}
func (s *VideoMergeService) updateMergeError(mergeID uint, errorMsg string) {
s.db.Model(&models.VideoMerge{}).Where("id = ?", mergeID).Updates(map[string]interface{}{
"status": models.VideoMergeStatusFailed,
"error_msg": errorMsg,
})
s.log.Errorw("Video merge failed", "id", mergeID, "error", errorMsg)
}
func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient, error) {
config, err := s.aiService.GetDefaultConfig("video")
if err != nil {
return nil, fmt.Errorf("failed to get video config: %w", err)
}
// 使用第一个模型
model := ""
if len(config.Model) > 0 {
model = config.Model[0]
}
// 根据配置中的 provider 创建对应的客户端
var endpoint string
var queryEndpoint string
switch config.Provider {
case "runway":
return video.NewRunwayClient(config.BaseURL, config.APIKey, model), nil
case "pika":
return video.NewPikaClient(config.BaseURL, config.APIKey, model), nil
case "openai", "sora":
return video.NewOpenAISoraClient(config.BaseURL, config.APIKey, model), nil
case "minimax":
return video.NewMinimaxClient(config.BaseURL, config.APIKey, model), nil
case "chatfire":
endpoint = "/video/generations"
queryEndpoint = "/video/task/{taskId}"
return video.NewChatfireClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
case "doubao", "volces", "ark":
endpoint = "/contents/generations/tasks"
queryEndpoint = "/generations/tasks/{taskId}"
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
default:
endpoint = "/contents/generations/tasks"
queryEndpoint = "/generations/tasks/{taskId}"
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
}
}
func (s *VideoMergeService) GetMerge(mergeID uint) (*models.VideoMerge, error) {
var merge models.VideoMerge
if err := s.db.Where("id = ? ", mergeID).First(&merge).Error; err != nil {
return nil, err
}
return &merge, nil
}
func (s *VideoMergeService) ListMerges(episodeID *string, status string, page, pageSize int) ([]models.VideoMerge, int64, error) {
query := s.db.Model(&models.VideoMerge{})
if episodeID != nil && *episodeID != "" {
query = query.Where("episode_id = ?", *episodeID)
}
if status != "" {
query = query.Where("status = ?", status)
}
var total int64
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
var merges []models.VideoMerge
offset := (page - 1) * pageSize
if err := query.Order("created_at DESC").Offset(offset).Limit(pageSize).Find(&merges).Error; err != nil {
return nil, 0, err
}
return merges, total, nil
}
func (s *VideoMergeService) DeleteMerge(mergeID uint) error {
result := s.db.Where("id = ? ", mergeID).Delete(&models.VideoMerge{})
if result.Error != nil {
return result.Error
}
if result.RowsAffected == 0 {
return fmt.Errorf("merge not found")
}
return nil
}
// TimelineClip 时间线片段数据
type TimelineClip struct {
AssetID string `json:"asset_id"` // 素材库视频ID优先使用
StoryboardID string `json:"storyboard_id"` // 分镜IDfallback
Order int `json:"order"`
StartTime float64 `json:"start_time"`
EndTime float64 `json:"end_time"`
Duration float64 `json:"duration"`
Transition map[string]interface{} `json:"transition"`
}
// FinalizeEpisodeRequest 完成剧集制作请求
type FinalizeEpisodeRequest struct {
EpisodeID string `json:"episode_id"`
Clips []TimelineClip `json:"clips"`
}
// FinalizeEpisode 完成集数制作,根据时间线场景顺序合成最终视频
func (s *VideoMergeService) FinalizeEpisode(episodeID string, timelineData *FinalizeEpisodeRequest) (map[string]interface{}, error) {
// 验证episode存在且属于该用户
var episode models.Episode
if err := s.db.Preload("Drama").Preload("Storyboards").Where("id = ?", episodeID).First(&episode).Error; err != nil {
return nil, fmt.Errorf("episode not found")
}
// 构建分镜ID映射
sceneMap := make(map[string]models.Storyboard)
for _, scene := range episode.Storyboards {
sceneMap[fmt.Sprintf("%d", scene.ID)] = scene
}
// 根据时间线数据构建场景片段
var sceneClips []models.SceneClip
var skippedScenes []int
if timelineData != nil && len(timelineData.Clips) > 0 {
// 使用前端提供的时间线数据
for _, clip := range timelineData.Clips {
// 优先使用素材库中的视频通过AssetID
var videoURL string
var sceneID uint
if clip.AssetID != "" {
// 从素材库获取视频URL
var asset models.Asset
if err := s.db.Where("id = ? AND type = ?", clip.AssetID, models.AssetTypeVideo).First(&asset).Error; err == nil {
videoURL = asset.URL
// 如果asset关联了storyboard使用关联的storyboard_id
if asset.StoryboardID != nil {
sceneID = *asset.StoryboardID
}
s.log.Infow("Using video from asset library", "asset_id", clip.AssetID, "video_url", videoURL)
} else {
s.log.Warnw("Asset not found, will try storyboard video", "asset_id", clip.AssetID, "error", err)
}
}
// 如果没有从素材库获取到视频尝试从storyboard获取
if videoURL == "" && clip.StoryboardID != "" {
scene, exists := sceneMap[clip.StoryboardID]
if !exists {
s.log.Warnw("Storyboard not found in episode, skipping", "storyboard_id", clip.StoryboardID)
continue
}
if scene.VideoURL != nil && *scene.VideoURL != "" {
videoURL = *scene.VideoURL
sceneID = scene.ID
s.log.Infow("Using video from storyboard", "storyboard_id", clip.StoryboardID, "video_url", videoURL)
}
}
// 如果仍然没有视频URL跳过该片段
if videoURL == "" {
s.log.Warnw("No video available for clip, skipping", "clip", clip)
if clip.StoryboardID != "" {
if scene, exists := sceneMap[clip.StoryboardID]; exists {
skippedScenes = append(skippedScenes, scene.StoryboardNumber)
}
}
continue
}
sceneClip := models.SceneClip{
SceneID: sceneID,
VideoURL: videoURL,
Duration: clip.Duration,
Order: clip.Order,
StartTime: clip.StartTime,
EndTime: clip.EndTime,
Transition: clip.Transition,
}
s.log.Infow("Adding scene clip with transition",
"scene_id", sceneID,
"order", clip.Order,
"transition", clip.Transition)
sceneClips = append(sceneClips, sceneClip)
}
} else {
// 没有时间线数据,使用默认场景顺序
if len(episode.Storyboards) == 0 {
return nil, fmt.Errorf("no scenes found for this episode")
}
order := 0
for _, scene := range episode.Storyboards {
// 优先从素材库查找该分镜关联的视频
var videoURL string
var asset models.Asset
if err := s.db.Where("storyboard_id = ? AND type = ? AND episode_id = ?",
scene.ID, models.AssetTypeVideo, episode.ID).
Order("created_at DESC").
First(&asset).Error; err == nil {
videoURL = asset.URL
s.log.Infow("Using video from asset library for storyboard",
"storyboard_id", scene.ID,
"asset_id", asset.ID,
"video_url", videoURL)
} else if scene.VideoURL != nil && *scene.VideoURL != "" {
// 如果素材库没有使用storyboard的video_url作为fallback
videoURL = *scene.VideoURL
s.log.Infow("Using fallback video from storyboard",
"storyboard_id", scene.ID,
"video_url", videoURL)
}
// 跳过没有视频的场景
if videoURL == "" {
s.log.Warnw("Scene has no video, skipping", "storyboard_number", scene.StoryboardNumber)
skippedScenes = append(skippedScenes, scene.StoryboardNumber)
continue
}
clip := models.SceneClip{
SceneID: scene.ID,
VideoURL: videoURL,
Duration: float64(scene.Duration),
Order: order,
}
sceneClips = append(sceneClips, clip)
order++
}
}
// 检查是否至少有一个场景可以合成
if len(sceneClips) == 0 {
return nil, fmt.Errorf("no scenes with videos available for merging")
}
// 创建视频合成任务
title := fmt.Sprintf("%s - 第%d集", episode.Drama.Title, episode.EpisodeNum)
finalReq := &MergeVideoRequest{
EpisodeID: episodeID,
DramaID: fmt.Sprintf("%d", episode.DramaID),
Title: title,
Scenes: sceneClips,
Provider: "doubao", // 默认使用doubao
}
// 执行视频合成
videoMerge, err := s.MergeVideos(finalReq)
if err != nil {
return nil, fmt.Errorf("failed to start video merge: %w", err)
}
// 更新episode状态为processing
s.db.Model(&episode).Updates(map[string]interface{}{
"status": "processing",
})
result := map[string]interface{}{
"message": "视频合成任务已创建,正在后台处理",
"merge_id": videoMerge.ID,
"episode_id": episodeID,
"scenes_count": len(sceneClips),
}
// 如果有跳过的场景,添加提示信息
if len(skippedScenes) > 0 {
result["skipped_scenes"] = skippedScenes
result["warning"] = fmt.Sprintf("已跳过 %d 个未生成视频的场景(场景编号:%v", len(skippedScenes), skippedScenes)
}
return result, nil
}