create project

This commit is contained in:
2026-01-16 17:30:40 +08:00
commit effac6b017
157 changed files with 45997 additions and 0 deletions

View File

@@ -0,0 +1,566 @@
package services
import (
"encoding/json"
"fmt"
"strconv"
"time"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/infrastructure/storage"
"github.com/drama-generator/backend/pkg/logger"
"github.com/drama-generator/backend/pkg/video"
"gorm.io/gorm"
)
type VideoGenerationService struct {
db *gorm.DB
transferService *ResourceTransferService
log *logger.Logger
localStorage *storage.LocalStorage
aiService *AIService
}
func NewVideoGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, aiService *AIService, log *logger.Logger) *VideoGenerationService {
service := &VideoGenerationService{
db: db,
localStorage: localStorage,
transferService: transferService,
aiService: aiService,
log: log,
}
go service.RecoverPendingTasks()
return service
}
type GenerateVideoRequest struct {
StoryboardID *uint `json:"storyboard_id"`
DramaID string `json:"drama_id" binding:"required"`
ImageGenID *uint `json:"image_gen_id"`
// 参考图模式single, first_last, multiple, none
ReferenceMode string `json:"reference_mode"`
// 单图模式
ImageURL string `json:"image_url"`
// 首尾帧模式
FirstFrameURL *string `json:"first_frame_url"`
LastFrameURL *string `json:"last_frame_url"`
// 多图模式
ReferenceImageURLs []string `json:"reference_image_urls"`
Prompt string `json:"prompt" binding:"required,min=5,max=2000"`
Provider string `json:"provider"`
Model string `json:"model"`
Duration *int `json:"duration"`
FPS *int `json:"fps"`
AspectRatio *string `json:"aspect_ratio"`
Style *string `json:"style"`
MotionLevel *int `json:"motion_level"`
CameraMotion *string `json:"camera_motion"`
Seed *int64 `json:"seed"`
}
func (s *VideoGenerationService) GenerateVideo(request *GenerateVideoRequest) (*models.VideoGeneration, error) {
if request.StoryboardID != nil {
var storyboard models.Storyboard
if err := s.db.Preload("Episode").Where("id = ?", *request.StoryboardID).First(&storyboard).Error; err != nil {
return nil, fmt.Errorf("storyboard not found")
}
if fmt.Sprintf("%d", storyboard.Episode.DramaID) != request.DramaID {
return nil, fmt.Errorf("storyboard does not belong to drama")
}
}
if request.ImageGenID != nil {
var imageGen models.ImageGeneration
if err := s.db.Where("id = ?", *request.ImageGenID).First(&imageGen).Error; err != nil {
return nil, fmt.Errorf("image generation not found")
}
}
provider := request.Provider
if provider == "" {
provider = "doubao"
}
dramaID, _ := strconv.ParseUint(request.DramaID, 10, 32)
videoGen := &models.VideoGeneration{
StoryboardID: request.StoryboardID,
DramaID: uint(dramaID),
ImageGenID: request.ImageGenID,
Provider: provider,
Prompt: request.Prompt,
Model: request.Model,
Duration: request.Duration,
FPS: request.FPS,
AspectRatio: request.AspectRatio,
Style: request.Style,
MotionLevel: request.MotionLevel,
CameraMotion: request.CameraMotion,
Seed: request.Seed,
Status: models.VideoStatusPending,
}
// 根据参考图模式处理不同的参数
if request.ReferenceMode != "" {
videoGen.ReferenceMode = &request.ReferenceMode
}
switch request.ReferenceMode {
case "single":
// 单图模式
if request.ImageURL != "" {
videoGen.ImageURL = &request.ImageURL
}
case "first_last":
// 首尾帧模式
if request.FirstFrameURL != nil {
videoGen.FirstFrameURL = request.FirstFrameURL
}
if request.LastFrameURL != nil {
videoGen.LastFrameURL = request.LastFrameURL
}
case "multiple":
// 多图模式
if len(request.ReferenceImageURLs) > 0 {
referenceImagesJSON, err := json.Marshal(request.ReferenceImageURLs)
if err == nil {
referenceImagesStr := string(referenceImagesJSON)
videoGen.ReferenceImageURLs = &referenceImagesStr
}
}
case "none":
// 无参考图,纯文本生成
default:
// 向后兼容:如果没有指定模式,根据提供的参数自动判断
if request.ImageURL != "" {
videoGen.ImageURL = &request.ImageURL
mode := "single"
videoGen.ReferenceMode = &mode
} else if request.FirstFrameURL != nil || request.LastFrameURL != nil {
videoGen.FirstFrameURL = request.FirstFrameURL
videoGen.LastFrameURL = request.LastFrameURL
mode := "first_last"
videoGen.ReferenceMode = &mode
} else if len(request.ReferenceImageURLs) > 0 {
referenceImagesJSON, err := json.Marshal(request.ReferenceImageURLs)
if err == nil {
referenceImagesStr := string(referenceImagesJSON)
videoGen.ReferenceImageURLs = &referenceImagesStr
mode := "multiple"
videoGen.ReferenceMode = &mode
}
}
}
if err := s.db.Create(videoGen).Error; err != nil {
return nil, fmt.Errorf("failed to create record: %w", err)
}
go s.ProcessVideoGeneration(videoGen.ID)
return videoGen, nil
}
func (s *VideoGenerationService) ProcessVideoGeneration(videoGenID uint) {
var videoGen models.VideoGeneration
if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
s.log.Errorw("Failed to load video generation", "error", err, "id", videoGenID)
return
}
s.db.Model(&videoGen).Update("status", models.VideoStatusProcessing)
client, err := s.getVideoClient(videoGen.Provider, videoGen.Model)
if err != nil {
s.log.Errorw("Failed to get video client", "error", err, "provider", videoGen.Provider, "model", videoGen.Model)
s.updateVideoGenError(videoGenID, err.Error())
return
}
s.log.Infow("Starting video generation", "id", videoGenID, "prompt", videoGen.Prompt, "provider", videoGen.Provider)
var opts []video.VideoOption
if videoGen.Model != "" {
opts = append(opts, video.WithModel(videoGen.Model))
}
if videoGen.Duration != nil {
opts = append(opts, video.WithDuration(*videoGen.Duration))
}
if videoGen.FPS != nil {
opts = append(opts, video.WithFPS(*videoGen.FPS))
}
if videoGen.AspectRatio != nil {
opts = append(opts, video.WithAspectRatio(*videoGen.AspectRatio))
}
if videoGen.Style != nil {
opts = append(opts, video.WithStyle(*videoGen.Style))
}
if videoGen.MotionLevel != nil {
opts = append(opts, video.WithMotionLevel(*videoGen.MotionLevel))
}
if videoGen.CameraMotion != nil {
opts = append(opts, video.WithCameraMotion(*videoGen.CameraMotion))
}
if videoGen.Seed != nil {
opts = append(opts, video.WithSeed(*videoGen.Seed))
}
// 根据参考图模式添加相应的选项
if videoGen.ReferenceMode != nil {
switch *videoGen.ReferenceMode {
case "first_last":
// 首尾帧模式
if videoGen.FirstFrameURL != nil {
opts = append(opts, video.WithFirstFrame(*videoGen.FirstFrameURL))
}
if videoGen.LastFrameURL != nil {
opts = append(opts, video.WithLastFrame(*videoGen.LastFrameURL))
}
case "multiple":
// 多图模式
if videoGen.ReferenceImageURLs != nil {
var imageURLs []string
if err := json.Unmarshal([]byte(*videoGen.ReferenceImageURLs), &imageURLs); err == nil {
opts = append(opts, video.WithReferenceImages(imageURLs))
}
}
}
}
// 构造imageURL参数单图模式使用其他模式传空字符串
imageURL := ""
if videoGen.ImageURL != nil {
imageURL = *videoGen.ImageURL
}
result, err := client.GenerateVideo(imageURL, videoGen.Prompt, opts...)
if err != nil {
s.log.Errorw("Video generation API call failed", "error", err, "id", videoGenID)
s.updateVideoGenError(videoGenID, err.Error())
return
}
if result.TaskID != "" {
s.db.Model(&videoGen).Updates(map[string]interface{}{
"task_id": result.TaskID,
"status": models.VideoStatusProcessing,
})
go s.pollTaskStatus(videoGenID, result.TaskID, videoGen.Provider, videoGen.Model)
return
}
if result.VideoURL != "" {
s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
return
}
s.updateVideoGenError(videoGenID, "no task ID or video URL returned")
}
func (s *VideoGenerationService) pollTaskStatus(videoGenID uint, taskID string, provider string, model string) {
client, err := s.getVideoClient(provider, model)
if err != nil {
s.log.Errorw("Failed to get video client for polling", "error", err)
s.updateVideoGenError(videoGenID, "failed to get video client")
return
}
maxAttempts := 300
interval := 10 * time.Second
for attempt := 0; attempt < maxAttempts; attempt++ {
time.Sleep(interval)
var videoGen models.VideoGeneration
if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
s.log.Errorw("Failed to load video generation", "error", err, "id", videoGenID)
return
}
if videoGen.Status != models.VideoStatusProcessing {
s.log.Infow("Video generation status changed, stopping poll", "id", videoGenID, "status", videoGen.Status)
return
}
result, err := client.GetTaskStatus(taskID)
if err != nil {
s.log.Errorw("Failed to get task status", "error", err, "task_id", taskID)
continue
}
if result.Completed {
if result.VideoURL != "" {
s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
return
}
s.updateVideoGenError(videoGenID, "task completed but no video URL")
return
}
if result.Error != "" {
s.updateVideoGenError(videoGenID, result.Error)
return
}
s.log.Infow("Video generation in progress", "id", videoGenID, "attempt", attempt+1)
}
s.updateVideoGenError(videoGenID, "polling timeout")
}
func (s *VideoGenerationService) completeVideoGeneration(videoGenID uint, videoURL string, duration *int, width *int, height *int, firstFrameURL *string) {
// 下载视频到本地存储(仅用于缓存,不更新数据库)
if s.localStorage != nil && videoURL != "" {
_, err := s.localStorage.DownloadFromURL(videoURL, "videos")
if err != nil {
s.log.Warnw("Failed to download video to local storage",
"error", err,
"id", videoGenID,
"original_url", videoURL)
} else {
s.log.Infow("Video downloaded to local storage for caching",
"id", videoGenID,
"original_url", videoURL)
}
}
// 下载首帧图片到本地存储(仅用于缓存,不更新数据库)
if firstFrameURL != nil && *firstFrameURL != "" && s.localStorage != nil {
_, err := s.localStorage.DownloadFromURL(*firstFrameURL, "video_frames")
if err != nil {
s.log.Warnw("Failed to download first frame to local storage",
"error", err,
"id", videoGenID,
"original_url", *firstFrameURL)
} else {
s.log.Infow("First frame downloaded to local storage for caching",
"id", videoGenID,
"original_url", *firstFrameURL)
}
}
// 数据库中保持使用原始URL
updates := map[string]interface{}{
"status": models.VideoStatusCompleted,
"video_url": videoURL,
}
if duration != nil {
updates["duration"] = *duration
}
if width != nil {
updates["width"] = *width
}
if height != nil {
updates["height"] = *height
}
if firstFrameURL != nil {
updates["first_frame_url"] = *firstFrameURL
}
if err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(updates).Error; err != nil {
s.log.Errorw("Failed to update video generation", "error", err, "id", videoGenID)
return
}
var videoGen models.VideoGeneration
if err := s.db.First(&videoGen, videoGenID).Error; err == nil {
if videoGen.StoryboardID != nil {
if err := s.db.Model(&models.Storyboard{}).Where("id = ?", *videoGen.StoryboardID).Update("video_url", videoURL).Error; err != nil {
s.log.Warnw("Failed to update storyboard video_url", "storyboard_id", *videoGen.StoryboardID, "error", err)
}
}
}
s.log.Infow("Video generation completed", "id", videoGenID, "url", videoURL)
}
func (s *VideoGenerationService) updateVideoGenError(videoGenID uint, errorMsg string) {
if err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(map[string]interface{}{
"status": models.VideoStatusFailed,
"error_msg": errorMsg,
}).Error; err != nil {
s.log.Errorw("Failed to update video generation error", "error", err, "id", videoGenID)
}
}
func (s *VideoGenerationService) getVideoClient(provider string, modelName string) (video.VideoClient, error) {
// 根据模型名称获取AI配置
var config *models.AIServiceConfig
var err error
if modelName != "" {
config, err = s.aiService.GetConfigForModel("video", modelName)
if err != nil {
s.log.Warnw("Failed to get config for model, using default", "model", modelName, "error", err)
config, err = s.aiService.GetDefaultConfig("video")
if err != nil {
return nil, fmt.Errorf("no video AI config found: %w", err)
}
}
} else {
config, err = s.aiService.GetDefaultConfig("video")
if err != nil {
return nil, fmt.Errorf("no video AI config found: %w", err)
}
}
// 使用配置中的信息创建客户端
baseURL := config.BaseURL
apiKey := config.APIKey
model := modelName
if model == "" && len(config.Model) > 0 {
model = config.Model[0]
}
// 根据配置中的 provider 创建对应的客户端
var endpoint string
var queryEndpoint string
switch config.Provider {
case "chatfire":
endpoint = "/video/generations"
queryEndpoint = "/video/task/{taskId}"
return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
case "doubao", "volcengine", "volces":
endpoint = "/contents/generations/tasks"
queryEndpoint = "/contents/generations/tasks/{taskId}"
return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
case "openai":
// OpenAI Sora 使用 /v1/videos 端点
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
case "runway":
return video.NewRunwayClient(baseURL, apiKey, model), nil
case "pika":
return video.NewPikaClient(baseURL, apiKey, model), nil
case "minimax":
return video.NewMinimaxClient(baseURL, apiKey, model), nil
default:
return nil, fmt.Errorf("unsupported video provider: %s", provider)
}
}
func (s *VideoGenerationService) RecoverPendingTasks() {
var pendingVideos []models.VideoGeneration
if err := s.db.Where("status = ? AND task_id != ''", models.VideoStatusProcessing).Find(&pendingVideos).Error; err != nil {
s.log.Errorw("Failed to load pending video tasks", "error", err)
return
}
s.log.Infow("Recovering pending video generation tasks", "count", len(pendingVideos))
for _, videoGen := range pendingVideos {
go s.pollTaskStatus(videoGen.ID, *videoGen.TaskID, videoGen.Provider, videoGen.Model)
}
}
func (s *VideoGenerationService) GetVideoGeneration(id uint) (*models.VideoGeneration, error) {
var videoGen models.VideoGeneration
if err := s.db.First(&videoGen, id).Error; err != nil {
return nil, err
}
return &videoGen, nil
}
func (s *VideoGenerationService) ListVideoGenerations(dramaID *uint, storyboardID *uint, status string, limit int, offset int) ([]*models.VideoGeneration, int64, error) {
var videos []*models.VideoGeneration
var total int64
query := s.db.Model(&models.VideoGeneration{})
if dramaID != nil {
query = query.Where("drama_id = ?", *dramaID)
}
if storyboardID != nil {
query = query.Where("storyboard_id = ?", *storyboardID)
}
if status != "" {
query = query.Where("status = ?", status)
}
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Limit(limit).Offset(offset).Find(&videos).Error; err != nil {
return nil, 0, err
}
return videos, total, nil
}
func (s *VideoGenerationService) GenerateVideoFromImage(imageGenID uint) (*models.VideoGeneration, error) {
var imageGen models.ImageGeneration
if err := s.db.First(&imageGen, imageGenID).Error; err != nil {
return nil, fmt.Errorf("image generation not found")
}
if imageGen.Status != models.ImageStatusCompleted || imageGen.ImageURL == nil {
return nil, fmt.Errorf("image is not ready")
}
// 获取关联的Storyboard以获取时长
var duration *int
if imageGen.StoryboardID != nil {
var storyboard models.Storyboard
if err := s.db.Where("id = ?", *imageGen.StoryboardID).First(&storyboard).Error; err == nil {
duration = &storyboard.Duration
s.log.Infow("Using storyboard duration for video generation",
"storyboard_id", *imageGen.StoryboardID,
"duration", storyboard.Duration)
}
}
req := &GenerateVideoRequest{
DramaID: fmt.Sprintf("%d", imageGen.DramaID),
StoryboardID: imageGen.StoryboardID,
ImageGenID: &imageGenID,
ImageURL: *imageGen.ImageURL,
Prompt: imageGen.Prompt,
Provider: "doubao",
Duration: duration,
}
return s.GenerateVideo(req)
}
func (s *VideoGenerationService) BatchGenerateVideosForEpisode(episodeID string) ([]*models.VideoGeneration, error) {
var episode models.Episode
if err := s.db.Preload("Storyboards").Where("id = ?", episodeID).First(&episode).Error; err != nil {
return nil, fmt.Errorf("episode not found")
}
var results []*models.VideoGeneration
for _, storyboard := range episode.Storyboards {
if storyboard.ImagePrompt == nil {
continue
}
var imageGen models.ImageGeneration
if err := s.db.Where("storyboard_id = ? AND status = ?", storyboard.ID, models.ImageStatusCompleted).
Order("created_at DESC").First(&imageGen).Error; err != nil {
s.log.Warnw("No completed image for storyboard", "storyboard_id", storyboard.ID)
continue
}
videoGen, err := s.GenerateVideoFromImage(imageGen.ID)
if err != nil {
s.log.Errorw("Failed to generate video", "storyboard_id", storyboard.ID, "error", err)
continue
}
results = append(results, videoGen)
}
return results, nil
}
func (s *VideoGenerationService) DeleteVideoGeneration(id uint) error {
return s.db.Delete(&models.VideoGeneration{}, id).Error
}