create project

This commit is contained in:
2026-01-16 17:30:40 +08:00
commit effac6b017
157 changed files with 45997 additions and 0 deletions

7
pkg/ai/client.go Normal file
View File

@@ -0,0 +1,7 @@
package ai
// AIClient 定义文本生成客户端接口
type AIClient interface {
GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error)
TestConnection() error
}

195
pkg/ai/gemini_client.go Normal file
View File

@@ -0,0 +1,195 @@
package ai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
type GeminiClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
HTTPClient *http.Client
}
type GeminiTextRequest struct {
Contents []GeminiContent `json:"contents"`
SystemInstruction *GeminiInstruction `json:"systemInstruction,omitempty"`
}
type GeminiContent struct {
Parts []GeminiPart `json:"parts"`
Role string `json:"role,omitempty"`
}
type GeminiPart struct {
Text string `json:"text"`
}
type GeminiInstruction struct {
Parts []GeminiPart `json:"parts"`
}
type GeminiTextResponse struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
Role string `json:"role"`
} `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
SafetyRatings []struct {
Category string `json:"category"`
Probability string `json:"probability"`
} `json:"safetyRatings"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
func NewGeminiClient(baseURL, apiKey, model, endpoint string) *GeminiClient {
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
if endpoint == "" {
endpoint = "/v1beta/models/{model}:generateContent"
}
if model == "" {
model = "gemini-3-pro"
}
return &GeminiClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *GeminiClient) GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) {
model := c.Model
// 构建请求体
reqBody := GeminiTextRequest{
Contents: []GeminiContent{
{
Parts: []GeminiPart{{Text: prompt}},
Role: "user",
},
},
}
// 使用 systemInstruction 字段处理系统提示
if systemPrompt != "" {
reqBody.SystemInstruction = &GeminiInstruction{
Parts: []GeminiPart{{Text: systemPrompt}},
}
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
fmt.Printf("Gemini: Failed to marshal request: %v\n", err)
return "", fmt.Errorf("marshal request: %w", err)
}
// 替换端点中的 {model} 占位符
endpoint := c.BaseURL + c.Endpoint
endpoint = strings.ReplaceAll(endpoint, "{model}", model)
url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey)
// 打印请求信息(隐藏 API Key
safeURL := strings.Replace(url, c.APIKey, "***", 1)
fmt.Printf("Gemini: Sending request to: %s\n", safeURL)
requestPreview := string(jsonData)
if len(jsonData) > 300 {
requestPreview = string(jsonData[:300]) + "..."
}
fmt.Printf("Gemini: Request body: %s\n", requestPreview)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
fmt.Printf("Gemini: Failed to create request: %v\n", err)
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
fmt.Printf("Gemini: Executing HTTP request...\n")
resp, err := c.HTTPClient.Do(req)
if err != nil {
fmt.Printf("Gemini: HTTP request failed: %v\n", err)
return "", fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
fmt.Printf("Gemini: Received response with status: %d\n", resp.StatusCode)
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("Gemini: Failed to read response body: %v\n", err)
return "", fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
fmt.Printf("Gemini: API error (status %d): %s\n", resp.StatusCode, string(body))
return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
// 打印响应体用于调试
bodyPreview := string(body)
if len(body) > 500 {
bodyPreview = string(body[:500]) + "..."
}
fmt.Printf("Gemini: Response body: %s\n", bodyPreview)
var result GeminiTextResponse
if err := json.Unmarshal(body, &result); err != nil {
errorPreview := string(body)
if len(body) > 200 {
errorPreview = string(body[:200])
}
fmt.Printf("Gemini: Failed to parse response: %v\n", err)
return "", fmt.Errorf("parse response: %w, body preview: %s", err, errorPreview)
}
fmt.Printf("Gemini: Successfully parsed response, candidates count: %d\n", len(result.Candidates))
if len(result.Candidates) == 0 {
fmt.Printf("Gemini: No candidates in response\n")
return "", fmt.Errorf("no candidates in response")
}
if len(result.Candidates[0].Content.Parts) == 0 {
fmt.Printf("Gemini: No parts in first candidate\n")
return "", fmt.Errorf("no parts in response")
}
responseText := result.Candidates[0].Content.Parts[0].Text
fmt.Printf("Gemini: Generated text: %s\n", responseText)
return responseText, nil
}
func (c *GeminiClient) TestConnection() error {
fmt.Printf("Gemini: TestConnection called with BaseURL=%s, Model=%s, Endpoint=%s\n", c.BaseURL, c.Model, c.Endpoint)
_, err := c.GenerateText("Hello", "")
if err != nil {
fmt.Printf("Gemini: TestConnection failed: %v\n", err)
} else {
fmt.Printf("Gemini: TestConnection succeeded\n")
}
return err
}

227
pkg/ai/openai_client.go Normal file
View File

@@ -0,0 +1,227 @@
package ai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type OpenAIClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
HTTPClient *http.Client
}
type ChatMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []ChatMessage `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
}
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
Message struct {
Role string `json:"role"`
Content string `json:"content"`
} `json:"message"`
FinishReason string `json:"finish_reason"`
} `json:"choices"`
Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
}
type ErrorResponse struct {
Error struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code"`
} `json:"error"`
}
func NewOpenAIClient(baseURL, apiKey, model, endpoint string) *OpenAIClient {
if endpoint == "" {
endpoint = "/v1/chat/completions"
}
return &OpenAIClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*ChatCompletionRequest)) (*ChatCompletionResponse, error) {
req := &ChatCompletionRequest{
Model: c.Model,
Messages: messages,
}
for _, option := range options {
option(req)
}
return c.sendChatRequest(req)
}
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
jsonData, err := json.Marshal(req)
if err != nil {
fmt.Printf("OpenAI: Failed to marshal request: %v\n", err)
return nil, fmt.Errorf("failed to marshal request: %w", err)
}
url := c.BaseURL + c.Endpoint
// 打印请求信息
fmt.Printf("OpenAI: Sending request to: %s\n", url)
fmt.Printf("OpenAI: BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model)
requestPreview := string(jsonData)
if len(jsonData) > 300 {
requestPreview = string(jsonData[:300]) + "..."
}
fmt.Printf("OpenAI: Request body: %s\n", requestPreview)
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
fmt.Printf("OpenAI: Failed to create request: %v\n", err)
return nil, fmt.Errorf("failed to create request: %w", err)
}
httpReq.Header.Set("Content-Type", "application/json")
httpReq.Header.Set("Authorization", "Bearer "+c.APIKey)
fmt.Printf("OpenAI: Executing HTTP request...\n")
resp, err := c.HTTPClient.Do(httpReq)
if err != nil {
fmt.Printf("OpenAI: HTTP request failed: %v\n", err)
return nil, fmt.Errorf("failed to send request: %w", err)
}
defer resp.Body.Close()
fmt.Printf("OpenAI: Received response with status: %d\n", resp.StatusCode)
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("OpenAI: Failed to read response body: %v\n", err)
return nil, fmt.Errorf("failed to read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
fmt.Printf("OpenAI: API error (status %d): %s\n", resp.StatusCode, string(body))
var errResp ErrorResponse
if err := json.Unmarshal(body, &errResp); err != nil {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
return nil, fmt.Errorf("API error: %s", errResp.Error.Message)
}
// 打印响应体用于调试
bodyPreview := string(body)
if len(body) > 500 {
bodyPreview = string(body[:500]) + "..."
}
fmt.Printf("OpenAI: Response body: %s\n", bodyPreview)
var chatResp ChatCompletionResponse
if err := json.Unmarshal(body, &chatResp); err != nil {
errorPreview := string(body)
if len(body) > 200 {
errorPreview = string(body[:200])
}
fmt.Printf("OpenAI: Failed to parse response: %v\n", err)
return nil, fmt.Errorf("failed to unmarshal response: %w, body preview: %s", err, errorPreview)
}
fmt.Printf("OpenAI: Successfully parsed response, choices count: %d\n", len(chatResp.Choices))
return &chatResp, nil
}
func WithTemperature(temp float64) func(*ChatCompletionRequest) {
return func(req *ChatCompletionRequest) {
req.Temperature = temp
}
}
func WithMaxTokens(tokens int) func(*ChatCompletionRequest) {
return func(req *ChatCompletionRequest) {
req.MaxTokens = tokens
}
}
func WithTopP(topP float64) func(*ChatCompletionRequest) {
return func(req *ChatCompletionRequest) {
req.TopP = topP
}
}
func (c *OpenAIClient) GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) {
messages := []ChatMessage{}
if systemPrompt != "" {
messages = append(messages, ChatMessage{
Role: "system",
Content: systemPrompt,
})
}
messages = append(messages, ChatMessage{
Role: "user",
Content: prompt,
})
resp, err := c.ChatCompletion(messages, options...)
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", fmt.Errorf("no response from API")
}
return resp.Choices[0].Message.Content, nil
}
func (c *OpenAIClient) TestConnection() error {
fmt.Printf("OpenAI: TestConnection called with BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model)
messages := []ChatMessage{
{
Role: "user",
Content: "Hello",
},
}
_, err := c.ChatCompletion(messages, WithMaxTokens(10))
if err != nil {
fmt.Printf("OpenAI: TestConnection failed: %v\n", err)
} else {
fmt.Printf("OpenAI: TestConnection succeeded\n")
}
return err
}

89
pkg/config/config.go Normal file
View File

@@ -0,0 +1,89 @@
package config
import (
"fmt"
"github.com/spf13/viper"
)
type Config struct {
App AppConfig `mapstructure:"app"`
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Storage StorageConfig `mapstructure:"storage"`
AI AIConfig `mapstructure:"ai"`
}
type AppConfig struct {
Name string `mapstructure:"name"`
Version string `mapstructure:"version"`
Debug bool `mapstructure:"debug"`
}
type ServerConfig struct {
Port int `mapstructure:"port"`
Host string `mapstructure:"host"`
CORSOrigins []string `mapstructure:"cors_origins"`
ReadTimeout int `mapstructure:"read_timeout"`
WriteTimeout int `mapstructure:"write_timeout"`
}
type DatabaseConfig struct {
Type string `mapstructure:"type"` // sqlite, mysql
Path string `mapstructure:"path"` // SQLite数据库文件路径
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Database string `mapstructure:"database"`
Charset string `mapstructure:"charset"`
MaxIdle int `mapstructure:"max_idle"`
MaxOpen int `mapstructure:"max_open"`
}
type StorageConfig struct {
Type string `mapstructure:"type"` // local, minio
LocalPath string `mapstructure:"local_path"` // 本地存储路径
BaseURL string `mapstructure:"base_url"` // 访问URL前缀
}
type AIConfig struct {
DefaultTextProvider string `mapstructure:"default_text_provider"`
DefaultImageProvider string `mapstructure:"default_image_provider"`
DefaultVideoProvider string `mapstructure:"default_video_provider"`
}
func LoadConfig() (*Config, error) {
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath("./configs")
viper.AddConfigPath(".")
viper.AutomaticEnv()
if err := viper.ReadInConfig(); err != nil {
return nil, fmt.Errorf("failed to read config: %w", err)
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
return &config, nil
}
func (c *DatabaseConfig) DSN() string {
if c.Type == "sqlite" {
return c.Path
}
// MySQL DSN
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
c.User,
c.Password,
c.Host,
c.Port,
c.Database,
c.Charset,
)
}

View File

@@ -0,0 +1,277 @@
package image
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
type GeminiImageClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
HTTPClient *http.Client
}
type GeminiImageRequest struct {
Contents []struct {
Parts []GeminiPart `json:"parts"`
} `json:"contents"`
GenerationConfig struct {
ResponseModalities []string `json:"responseModalities"`
} `json:"generationConfig"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"` // base64 编码的图片数据
}
type GeminiImageResponse struct {
Candidates []struct {
Content struct {
Parts []struct {
InlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
} `json:"inlineData,omitempty"`
Text string `json:"text,omitempty"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
// downloadImageToBase64 下载图片 URL 并转换为 base64
func downloadImageToBase64(imageURL string) (string, string, error) {
resp, err := http.Get(imageURL)
if err != nil {
return "", "", fmt.Errorf("download image: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", "", fmt.Errorf("download image failed with status: %d", resp.StatusCode)
}
imageData, err := io.ReadAll(resp.Body)
if err != nil {
return "", "", fmt.Errorf("read image data: %w", err)
}
// 根据 Content-Type 确定 mimeType
mimeType := resp.Header.Get("Content-Type")
if mimeType == "" {
mimeType = "image/jpeg"
}
base64Data := base64.StdEncoding.EncodeToString(imageData)
return base64Data, mimeType, nil
}
func NewGeminiImageClient(baseURL, apiKey, model, endpoint string) *GeminiImageClient {
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
if endpoint == "" {
endpoint = "/v1beta/models/{model}:generateContent"
}
if model == "" {
model = "gemini-3-pro-image-preview"
}
return &GeminiImageClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *GeminiImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Size: "1024x1024",
Quality: "standard",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
promptText := prompt
if options.NegativePrompt != "" {
promptText += fmt.Sprintf("\n\nNegative prompt: %s", options.NegativePrompt)
}
if options.Size != "" {
promptText += fmt.Sprintf("\n\nImage size: %s", options.Size)
}
// 构建请求的 parts支持参考图
parts := []GeminiPart{}
// 如果有参考图,先添加参考图
if len(options.ReferenceImages) > 0 {
for _, refImg := range options.ReferenceImages {
var base64Data string
var mimeType string
var err error
// 检查是否是 HTTP/HTTPS URL
if strings.HasPrefix(refImg, "http://") || strings.HasPrefix(refImg, "https://") {
// 下载图片并转换为 base64
base64Data, mimeType, err = downloadImageToBase64(refImg)
if err != nil {
continue
}
} else if strings.HasPrefix(refImg, "data:") {
// 如果是 data URI 格式,需要解析
// 格式: data:image/jpeg;base64,xxxxx
mimeType = "image/jpeg"
parts := []byte(refImg)
for i := 0; i < len(parts); i++ {
if parts[i] == ',' {
base64Data = refImg[i+1:]
// 提取 mime type
if i > 11 {
mimeTypeEnd := i
for j := 5; j < i; j++ {
if parts[j] == ';' {
mimeTypeEnd = j
break
}
}
mimeType = refImg[5:mimeTypeEnd]
}
break
}
}
} else {
// 假设已经是 base64 编码
base64Data = refImg
mimeType = "image/jpeg"
}
if base64Data != "" {
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
}
}
}
// 添加文本提示词
parts = append(parts, GeminiPart{
Text: promptText,
})
reqBody := GeminiImageRequest{
Contents: []struct {
Parts []GeminiPart `json:"parts"`
}{
{
Parts: parts,
},
},
GenerationConfig: struct {
ResponseModalities []string `json:"responseModalities"`
}{
ResponseModalities: []string{"IMAGE"},
},
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + c.Endpoint
endpoint = replaceModelPlaceholder(endpoint, model)
url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
bodyStr := string(body)
if len(bodyStr) > 1000 {
bodyStr = fmt.Sprintf("%s ... %s", bodyStr[:500], bodyStr[len(bodyStr)-500:])
}
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, bodyStr)
}
var result GeminiImageResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
return nil, fmt.Errorf("no image generated in response")
}
base64Data := result.Candidates[0].Content.Parts[0].InlineData.Data
if base64Data == "" {
return nil, fmt.Errorf("no base64 image data in response")
}
dataURI := fmt.Sprintf("data:image/jpeg;base64,%s", base64Data)
return &ImageResult{
Status: "completed",
ImageURL: dataURI,
Completed: true,
Width: 1024,
Height: 1024,
}, nil
}
func (c *GeminiImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
return nil, fmt.Errorf("not supported for Gemini (synchronous generation)")
}
func replaceModelPlaceholder(endpoint, model string) string {
result := endpoint
if bytes.Contains([]byte(result), []byte("{model}")) {
result = string(bytes.ReplaceAll([]byte(result), []byte("{model}"), []byte(model)))
}
return result
}

93
pkg/image/image_client.go Normal file
View File

@@ -0,0 +1,93 @@
package image
type ImageClient interface {
GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error)
GetTaskStatus(taskID string) (*ImageResult, error)
}
type ImageResult struct {
TaskID string
Status string
ImageURL string
Width int
Height int
Error string
Completed bool
}
type ImageOptions struct {
NegativePrompt string
Size string
Quality string
Style string
Steps int
CfgScale float64
Seed int64
Model string
Width int
Height int
ReferenceImages []string // 参考图片URL列表
}
type ImageOption func(*ImageOptions)
func WithNegativePrompt(prompt string) ImageOption {
return func(o *ImageOptions) {
o.NegativePrompt = prompt
}
}
func WithSize(size string) ImageOption {
return func(o *ImageOptions) {
o.Size = size
}
}
func WithQuality(quality string) ImageOption {
return func(o *ImageOptions) {
o.Quality = quality
}
}
func WithStyle(style string) ImageOption {
return func(o *ImageOptions) {
o.Style = style
}
}
func WithSteps(steps int) ImageOption {
return func(o *ImageOptions) {
o.Steps = steps
}
}
func WithCfgScale(scale float64) ImageOption {
return func(o *ImageOptions) {
o.CfgScale = scale
}
}
func WithSeed(seed int64) ImageOption {
return func(o *ImageOptions) {
o.Seed = seed
}
}
func WithModel(model string) ImageOption {
return func(o *ImageOptions) {
o.Model = model
}
}
func WithDimensions(width, height int) ImageOption {
return func(o *ImageOptions) {
o.Width = width
o.Height = height
}
}
func WithReferenceImages(images []string) ImageOption {
return func(o *ImageOptions) {
o.ReferenceImages = images
}
}

View File

@@ -0,0 +1,128 @@
package image
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type OpenAIImageClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
HTTPClient *http.Client
}
type DALLERequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
N int `json:"n"`
Image []string `json:"image,omitempty"`
}
type DALLEResponse struct {
Created int64 `json:"created"`
Data []struct {
URL string `json:"url"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
} `json:"data"`
}
func NewOpenAIImageClient(baseURL, apiKey, model, endpoint string) *OpenAIImageClient {
if endpoint == "" {
endpoint = "/v1/images/generations"
}
return &OpenAIImageClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Size: "1920x1920",
Quality: "standard",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := DALLERequest{
Model: model,
Prompt: prompt,
Size: options.Size,
Quality: options.Quality,
N: 1,
Image: options.ReferenceImages,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
url := c.BaseURL + c.Endpoint
fmt.Printf("[OpenAI Image] Request URL: %s\n", url)
fmt.Printf("[OpenAI Image] Request Body: %s\n", string(jsonData))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
fmt.Printf("OpenAI API Response: %s\n", string(body))
var result DALLEResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
}
if len(result.Data) == 0 {
return nil, fmt.Errorf("no image generated, response: %s", string(body))
}
return &ImageResult{
Status: "completed",
ImageURL: result.Data[0].URL,
Completed: true,
}, nil
}
func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
return nil, fmt.Errorf("not supported for OpenAI/DALL-E")
}

View File

@@ -0,0 +1,158 @@
package image
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type VolcEngineImageClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
QueryEndpoint string
HTTPClient *http.Client
}
type VolcEngineImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Image []string `json:"image,omitempty"`
SequentialImageGeneration string `json:"sequential_image_generation,omitempty"`
Size string `json:"size,omitempty"`
Watermark bool `json:"watermark,omitempty"`
}
type VolcEngineImageResponse struct {
Model string `json:"model"`
Created int64 `json:"created"`
Data []struct {
URL string `json:"url"`
Size string `json:"size"`
} `json:"data"`
Usage struct {
GeneratedImages int `json:"generated_images"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
Error interface{} `json:"error,omitempty"`
}
func NewVolcEngineImageClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *VolcEngineImageClient {
if endpoint == "" {
endpoint = "/api/v3/images/generations"
}
if queryEndpoint == "" {
queryEndpoint = endpoint
}
return &VolcEngineImageClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
QueryEndpoint: queryEndpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *VolcEngineImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Size: "1024x1024",
Quality: "standard",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
promptText := prompt
if options.NegativePrompt != "" {
promptText += fmt.Sprintf(". Negative: %s", options.NegativePrompt)
}
size := options.Size
if size == "" {
if model == "doubao-seedream-4-5-251128" {
size = "2K"
} else {
size = "1K"
}
}
reqBody := VolcEngineImageRequest{
Model: model,
Prompt: promptText,
Image: options.ReferenceImages,
SequentialImageGeneration: "disabled",
Size: size,
Watermark: false,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
url := c.BaseURL + c.Endpoint
fmt.Printf("[VolcEngine Image] Request URL: %s\n", url)
fmt.Printf("[VolcEngine Image] Request Body: %s\n", string(jsonData))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
fmt.Printf("VolcEngine Image API Response: %s\n", string(body))
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result VolcEngineImageResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error != nil {
return nil, fmt.Errorf("volcengine error: %v", result.Error)
}
if len(result.Data) == 0 {
return nil, fmt.Errorf("no image generated")
}
return &ImageResult{
Status: "completed",
ImageURL: result.Data[0].URL,
Completed: true,
}, nil
}
func (c *VolcEngineImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
return nil, fmt.Errorf("not supported for VolcEngine Seedream (synchronous generation)")
}

35
pkg/logger/logger.go Normal file
View File

@@ -0,0 +1,35 @@
package logger
import (
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
type Logger struct {
*zap.SugaredLogger
}
func NewLogger(debug bool) *Logger {
var config zap.Config
if debug {
config = zap.NewDevelopmentConfig()
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
// 在开发模式下,禁用时间戳和调用者信息,使输出更简洁
config.EncoderConfig.TimeKey = ""
config.EncoderConfig.CallerKey = ""
} else {
config = zap.NewProductionConfig()
config.EncoderConfig.TimeKey = "timestamp"
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
}
logger, err := config.Build()
if err != nil {
panic(err)
}
return &Logger{
SugaredLogger: logger.Sugar(),
}
}

119
pkg/response/response.go Normal file
View File

@@ -0,0 +1,119 @@
package response
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
)
type Response struct {
Success bool `json:"success"`
Data interface{} `json:"data,omitempty"`
Error *ErrorInfo `json:"error,omitempty"`
Message string `json:"message,omitempty"`
Timestamp string `json:"timestamp"`
}
type ErrorInfo struct {
Code string `json:"code"`
Message string `json:"message"`
Details interface{} `json:"details,omitempty"`
}
type PaginationData struct {
Items interface{} `json:"items"`
Pagination Pagination `json:"pagination"`
}
type Pagination struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Total int64 `json:"total"`
TotalPages int64 `json:"total_pages"`
}
func Success(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, Response{
Success: true,
Data: data,
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
c.JSON(http.StatusOK, Response{
Success: true,
Data: data,
Message: message,
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func Created(c *gin.Context, data interface{}) {
c.JSON(http.StatusCreated, Response{
Success: true,
Data: data,
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func SuccessWithPagination(c *gin.Context, items interface{}, total int64, page int, pageSize int) {
totalPages := (total + int64(pageSize) - 1) / int64(pageSize)
c.JSON(http.StatusOK, Response{
Success: true,
Data: PaginationData{
Items: items,
Pagination: Pagination{
Page: page,
PageSize: pageSize,
Total: total,
TotalPages: totalPages,
},
},
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func Error(c *gin.Context, statusCode int, errCode string, message string) {
c.JSON(statusCode, Response{
Success: false,
Error: &ErrorInfo{
Code: errCode,
Message: message,
},
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func ErrorWithDetails(c *gin.Context, statusCode int, errCode string, message string, details interface{}) {
c.JSON(statusCode, Response{
Success: false,
Error: &ErrorInfo{
Code: errCode,
Message: message,
Details: details,
},
Timestamp: time.Now().UTC().Format(time.RFC3339),
})
}
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, "BAD_REQUEST", message)
}
func Unauthorized(c *gin.Context, message string) {
Error(c, http.StatusUnauthorized, "UNAUTHORIZED", message)
}
func Forbidden(c *gin.Context, message string) {
Error(c, http.StatusForbidden, "FORBIDDEN", message)
}
func NotFound(c *gin.Context, message string) {
Error(c, http.StatusNotFound, "NOT_FOUND", message)
}
func InternalError(c *gin.Context, message string) {
Error(c, http.StatusInternalServerError, "INTERNAL_ERROR", message)
}

153
pkg/utils/json_parser.go Normal file
View File

@@ -0,0 +1,153 @@
package utils
import (
"encoding/json"
"fmt"
"regexp"
"strings"
)
// SafeParseAIJSON 安全地解析AI返回的JSON处理常见的格式问题
// 包括:
// 1. 移除Markdown代码块标记
// 2. 提取JSON对象
// 3. 清理多余的空白和换行
// 4. 尝试修复截断的JSON
// 5. 提供详细的错误信息
func SafeParseAIJSON(aiResponse string, v interface{}) error {
if aiResponse == "" {
return fmt.Errorf("AI返回内容为空")
}
// 1. 移除可能的Markdown代码块标记
cleaned := strings.TrimSpace(aiResponse)
cleaned = regexp.MustCompile("(?m)^```json\\s*").ReplaceAllString(cleaned, "")
cleaned = regexp.MustCompile("(?m)^```\\s*").ReplaceAllString(cleaned, "")
cleaned = strings.TrimSpace(cleaned)
// 2. 提取JSON对象 (查找第一个 { 到最后一个 })
jsonRegex := regexp.MustCompile(`(?s)\{.*\}`)
jsonMatch := jsonRegex.FindString(cleaned)
if jsonMatch == "" {
return fmt.Errorf("响应中未找到有效的JSON对象原始响应: %s", truncateString(aiResponse, 200))
}
// 3. 尝试解析JSON
err := json.Unmarshal([]byte(jsonMatch), v)
if err == nil {
return nil // 解析成功
}
// 4. 如果解析失败尝试修复截断的JSON
fixedJSON := attemptJSONRepair(jsonMatch)
if fixedJSON != jsonMatch {
if err := json.Unmarshal([]byte(fixedJSON), v); err == nil {
return nil // 修复后解析成功
}
}
// 5. 提供详细的错误上下文
if jsonErr, ok := err.(*json.SyntaxError); ok {
errorPos := int(jsonErr.Offset)
start := maxInt(0, errorPos-100)
end := minInt(len(jsonMatch), errorPos+100)
context := jsonMatch[start:end]
marker := strings.Repeat(" ", errorPos-start) + "^"
return fmt.Errorf(
"JSON解析失败: %s\n错误位置附近:\n%s\n%s",
jsonErr.Error(),
context,
marker,
)
}
return fmt.Errorf("JSON解析失败: %w\n原始响应: %s", err, truncateString(jsonMatch, 300))
}
// attemptJSONRepair 尝试修复常见的JSON问题
func attemptJSONRepair(jsonStr string) string {
// 1. 处理未闭合的字符串
// 如果最后一个字符不是 },尝试补全
trimmed := strings.TrimSpace(jsonStr)
// 2. 检查是否有未闭合的引号
if strings.Count(trimmed, `"`)%2 != 0 {
// 有奇数个引号,尝试补全最后一个引号
trimmed += `"`
}
// 3. 统计括号
openBraces := strings.Count(trimmed, "{")
closeBraces := strings.Count(trimmed, "}")
openBrackets := strings.Count(trimmed, "[")
closeBrackets := strings.Count(trimmed, "]")
// 4. 补全未闭合的数组
for i := 0; i < openBrackets-closeBrackets; i++ {
trimmed += "]"
}
// 5. 补全未闭合的对象
for i := 0; i < openBraces-closeBraces; i++ {
trimmed += "}"
}
return trimmed
}
// ExtractJSONFromText 从文本中提取JSON对象或数组
func ExtractJSONFromText(text string) string {
text = strings.TrimSpace(text)
// 移除Markdown代码块
text = regexp.MustCompile("(?m)^```json\\s*").ReplaceAllString(text, "")
text = regexp.MustCompile("(?m)^```\\s*").ReplaceAllString(text, "")
text = strings.TrimSpace(text)
// 查找JSON对象
if idx := strings.Index(text, "{"); idx != -1 {
if lastIdx := strings.LastIndex(text, "}"); lastIdx != -1 && lastIdx > idx {
return text[idx : lastIdx+1]
}
}
// 查找JSON数组
if idx := strings.Index(text, "["); idx != -1 {
if lastIdx := strings.LastIndex(text, "]"); lastIdx != -1 && lastIdx > idx {
return text[idx : lastIdx+1]
}
}
return text
}
// ValidateJSON 验证JSON字符串是否有效
func ValidateJSON(jsonStr string) error {
var js json.RawMessage
return json.Unmarshal([]byte(jsonStr), &js)
}
// Helper functions
func truncateString(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "..."
}
func maxInt(a, b int) int {
if a > b {
return a
}
return b
}
func minInt(a, b int) int {
if a < b {
return a
}
return b
}

View File

@@ -0,0 +1,419 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// ChatfireClient Chatfire 视频生成客户端
type ChatfireClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
QueryEndpoint string
HTTPClient *http.Client
}
type ChatfireRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
ImageURL string `json:"image_url,omitempty"`
Duration int `json:"duration,omitempty"`
Size string `json:"size,omitempty"`
}
// ChatfireSoraRequest Sora 模型请求格式
type ChatfireSoraRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Seconds string `json:"seconds,omitempty"`
Size string `json:"size,omitempty"`
InputReference string `json:"input_reference,omitempty"`
}
// ChatfireDoubaoRequest 豆包/火山模型请求格式
type ChatfireDoubaoRequest struct {
Model string `json:"model"`
Content []struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL map[string]interface{} `json:"image_url,omitempty"`
Role string `json:"role,omitempty"`
} `json:"content"`
}
type ChatfireResponse struct {
ID string `json:"id"`
TaskID string `json:"task_id,omitempty"`
Status string `json:"status,omitempty"`
Error json.RawMessage `json:"error,omitempty"`
Data struct {
ID string `json:"id,omitempty"`
Status string `json:"status,omitempty"`
VideoURL string `json:"video_url,omitempty"`
} `json:"data,omitempty"`
}
type ChatfireTaskResponse struct {
ID string `json:"id,omitempty"`
TaskID string `json:"task_id,omitempty"`
Status string `json:"status,omitempty"`
VideoURL string `json:"video_url,omitempty"`
Error json.RawMessage `json:"error,omitempty"`
Data struct {
ID string `json:"id,omitempty"`
Status string `json:"status,omitempty"`
VideoURL string `json:"video_url,omitempty"`
} `json:"data,omitempty"`
Content struct {
VideoURL string `json:"video_url,omitempty"`
} `json:"content,omitempty"`
}
// getErrorMessage 从 error 字段提取错误信息(支持字符串或对象)
func getErrorMessage(errorData json.RawMessage) string {
if len(errorData) == 0 {
return ""
}
// 尝试解析为字符串
var errStr string
if err := json.Unmarshal(errorData, &errStr); err == nil {
return errStr
}
// 尝试解析为对象
var errObj struct {
Message string `json:"message"`
Code string `json:"code"`
}
if err := json.Unmarshal(errorData, &errObj); err == nil {
if errObj.Message != "" {
return errObj.Message
}
}
// 返回原始 JSON 字符串
return string(errorData)
}
func NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *ChatfireClient {
if endpoint == "" {
endpoint = "/video/generations"
}
if queryEndpoint == "" {
queryEndpoint = "/video/task/{taskId}"
}
return &ChatfireClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
QueryEndpoint: queryEndpoint,
HTTPClient: &http.Client{
Timeout: 300 * time.Second,
},
}
}
func (c *ChatfireClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 5,
AspectRatio: "16:9",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
// 根据模型名称选择请求格式
var jsonData []byte
var err error
if strings.Contains(model, "doubao") || strings.Contains(model, "seedance") {
// 豆包/火山格式
reqBody := ChatfireDoubaoRequest{
Model: model,
}
// 构建prompt文本包含duration和ratio参数
promptText := prompt
if options.AspectRatio != "" {
promptText += fmt.Sprintf(" --ratio %s", options.AspectRatio)
}
if options.Duration > 0 {
promptText += fmt.Sprintf(" --dur %d", options.Duration)
}
// 添加文本内容
reqBody.Content = append(reqBody.Content, struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL map[string]interface{} `json:"image_url,omitempty"`
Role string `json:"role,omitempty"`
}{Type: "text", Text: promptText})
// 处理不同的图片模式
// 1. 组图模式多个reference_image
if len(options.ReferenceImageURLs) > 0 {
for _, refURL := range options.ReferenceImageURLs {
reqBody.Content = append(reqBody.Content, struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL map[string]interface{} `json:"image_url,omitempty"`
Role string `json:"role,omitempty"`
}{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": refURL,
},
Role: "reference_image",
})
}
} else if options.FirstFrameURL != "" && options.LastFrameURL != "" {
// 2. 首尾帧模式
reqBody.Content = append(reqBody.Content, struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL map[string]interface{} `json:"image_url,omitempty"`
Role string `json:"role,omitempty"`
}{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": options.FirstFrameURL,
},
Role: "first_frame",
})
reqBody.Content = append(reqBody.Content, struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL map[string]interface{} `json:"image_url,omitempty"`
Role string `json:"role,omitempty"`
}{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": options.LastFrameURL,
},
Role: "last_frame",
})
} else if imageURL != "" {
// 3. 单图模式(默认)
reqBody.Content = append(reqBody.Content, struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL map[string]interface{} `json:"image_url,omitempty"`
Role string `json:"role,omitempty"`
}{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": imageURL,
},
// 单图模式不需要role
})
} else if options.FirstFrameURL != "" {
// 4. 只有首帧
reqBody.Content = append(reqBody.Content, struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL map[string]interface{} `json:"image_url,omitempty"`
Role string `json:"role,omitempty"`
}{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": options.FirstFrameURL,
},
Role: "first_frame",
})
}
jsonData, err = json.Marshal(reqBody)
} else if strings.Contains(model, "sora") {
// Sora 格式
seconds := fmt.Sprintf("%d", options.Duration)
size := options.AspectRatio
if size == "16:9" {
size = "1280x720"
} else if size == "9:16" {
size = "720x1280"
}
reqBody := ChatfireSoraRequest{
Model: model,
Prompt: prompt,
Seconds: seconds,
Size: size,
InputReference: imageURL,
}
jsonData, err = json.Marshal(reqBody)
} else {
// 默认格式
reqBody := ChatfireRequest{
Model: model,
Prompt: prompt,
ImageURL: imageURL,
Duration: options.Duration,
Size: options.AspectRatio,
}
jsonData, err = json.Marshal(reqBody)
}
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + c.Endpoint
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
// 调试日志:打印响应内容
fmt.Printf("[Chatfire] Response body: %s\n", string(body))
var result ChatfireResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
}
// 优先使用 id 字段,其次使用 task_id
taskID := result.ID
if taskID == "" {
taskID = result.TaskID
}
// 如果有 data 嵌套,优先使用 data 中的值
if result.Data.ID != "" {
taskID = result.Data.ID
}
status := result.Status
if status == "" && result.Data.Status != "" {
status = result.Data.Status
}
fmt.Printf("[Chatfire] Parsed result - TaskID: %s, Status: %s\n", taskID, status)
if errMsg := getErrorMessage(result.Error); errMsg != "" {
return nil, fmt.Errorf("chatfire error: %s", errMsg)
}
videoResult := &VideoResult{
TaskID: taskID,
Status: status,
Completed: status == "completed" || status == "succeeded",
Duration: options.Duration,
}
return videoResult, nil
}
func (c *ChatfireClient) GetTaskStatus(taskID string) (*VideoResult, error) {
queryPath := c.QueryEndpoint
if strings.Contains(queryPath, "{taskId}") {
queryPath = strings.ReplaceAll(queryPath, "{taskId}", taskID)
} else if strings.Contains(queryPath, "{task_id}") {
queryPath = strings.ReplaceAll(queryPath, "{task_id}", taskID)
} else {
queryPath = queryPath + "/" + taskID
}
endpoint := c.BaseURL + queryPath
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
// 调试日志:打印响应内容
fmt.Printf("[Chatfire] GetTaskStatus Response body: %s\n", string(body))
var result ChatfireTaskResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
}
// 优先使用 id 字段,其次使用 task_id
responseTaskID := result.ID
if responseTaskID == "" {
responseTaskID = result.TaskID
}
// 如果有 data 嵌套,优先使用 data 中的值
if result.Data.ID != "" {
responseTaskID = result.Data.ID
}
status := result.Status
if status == "" && result.Data.Status != "" {
status = result.Data.Status
}
// 按优先级获取 video_urlVideoURL -> Data.VideoURL -> Content.VideoURL
videoURL := result.VideoURL
if videoURL == "" && result.Data.VideoURL != "" {
videoURL = result.Data.VideoURL
}
if videoURL == "" && result.Content.VideoURL != "" {
videoURL = result.Content.VideoURL
}
fmt.Printf("[Chatfire] Parsed result - TaskID: %s, Status: %s, VideoURL: %s\n", responseTaskID, status, videoURL)
videoResult := &VideoResult{
TaskID: responseTaskID,
Status: status,
Completed: status == "completed" || status == "succeeded",
}
if errMsg := getErrorMessage(result.Error); errMsg != "" {
videoResult.Error = errMsg
}
if videoURL != "" {
videoResult.VideoURL = videoURL
videoResult.Completed = true
}
return videoResult, nil
}

192
pkg/video/minimax_client.go Normal file
View File

@@ -0,0 +1,192 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
// MinimaxClient Minimax视频生成客户端
type MinimaxClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type MinimaxSubjectReference struct {
Type string `json:"type"`
Image []string `json:"image"`
}
type MinimaxRequest struct {
Prompt string `json:"prompt"`
FirstFrameImage string `json:"first_frame_image,omitempty"`
LastFrameImage string `json:"last_frame_image,omitempty"`
SubjectReference []MinimaxSubjectReference `json:"subject_reference,omitempty"`
Model string `json:"model"`
Duration int `json:"duration,omitempty"`
Resolution string `json:"resolution,omitempty"`
}
type MinimaxResponse struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
BaseResp struct {
StatusCode int `json:"status_code"`
StatusMsg string `json:"status_msg"`
} `json:"base_resp"`
Video struct {
URL string `json:"url"`
Duration int `json:"duration"`
} `json:"video"`
Error struct {
Code string `json:"code"`
Message string `json:"message"`
} `json:"error"`
}
func NewMinimaxClient(baseURL, apiKey, model string) *MinimaxClient {
return &MinimaxClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 300 * time.Second,
},
}
}
// GenerateVideo 生成视频(支持首尾帧和主体参考)
func (c *MinimaxClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 6,
Resolution: "1080P",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := MinimaxRequest{
Prompt: prompt,
Model: model,
Duration: options.Duration,
}
// 设置分辨率
if options.Resolution != "" {
reqBody.Resolution = options.Resolution
}
// 如果有首帧图片从imageURL或FirstFrameURL
if options.FirstFrameURL != "" {
reqBody.FirstFrameImage = options.FirstFrameURL
} else if imageURL != "" {
reqBody.FirstFrameImage = imageURL
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/video_generation"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result MinimaxResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error.Message != "" {
return nil, fmt.Errorf("minimax error: %s", result.Error.Message)
}
videoResult := &VideoResult{
TaskID: result.TaskID,
Status: result.Status,
Completed: result.Status == "completed",
Duration: result.Video.Duration,
}
if result.Video.URL != "" {
videoResult.VideoURL = result.Video.URL
videoResult.Completed = true
}
return videoResult, nil
}
func (c *MinimaxClient) GetTaskStatus(taskID string) (*VideoResult, error) {
endpoint := c.BaseURL + "/v1/video_generation/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result MinimaxResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
videoResult := &VideoResult{
TaskID: result.TaskID,
Status: result.Status,
Completed: result.Status == "completed",
Duration: result.Video.Duration,
}
if result.Error.Message != "" {
videoResult.Error = result.Error.Message
}
if result.Video.URL != "" {
videoResult.VideoURL = result.Video.URL
videoResult.Completed = true
}
return videoResult, nil
}

View File

@@ -0,0 +1,178 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"time"
)
type OpenAISoraClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type OpenAISoraResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Model string `json:"model"`
Status string `json:"status"`
Progress int `json:"progress"`
CreatedAt int64 `json:"created_at"`
CompletedAt int64 `json:"completed_at"`
Size string `json:"size"`
Seconds string `json:"seconds"`
Quality string `json:"quality"`
VideoURL string `json:"video_url"` // 直接的video_url字段
Video struct {
URL string `json:"url"`
} `json:"video"` // 嵌套的video.url字段兼容
Error struct {
Message string `json:"message"`
Type string `json:"type"`
} `json:"error"`
}
func NewOpenAISoraClient(baseURL, apiKey, model string) *OpenAISoraClient {
return &OpenAISoraClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 300 * time.Second,
},
}
}
func (c *OpenAISoraClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 4,
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
writer.WriteField("model", model)
writer.WriteField("prompt", prompt)
if imageURL != "" {
writer.WriteField("input_reference", imageURL)
}
if options.Duration > 0 {
writer.WriteField("seconds", fmt.Sprintf("%d", options.Duration))
}
if options.Resolution != "" {
writer.WriteField("size", options.Resolution)
}
writer.Close()
endpoint := c.BaseURL + "/videos"
req, err := http.NewRequest("POST", endpoint, body)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(respBody))
}
var result OpenAISoraResponse
if err := json.Unmarshal(respBody, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error.Message != "" {
return nil, fmt.Errorf("openai error: %s", result.Error.Message)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "completed",
}
// 优先使用video_url字段兼容video.url嵌套结构
if result.VideoURL != "" {
videoResult.VideoURL = result.VideoURL
} else if result.Video.URL != "" {
videoResult.VideoURL = result.Video.URL
}
return videoResult, nil
}
func (c *OpenAISoraClient) GetTaskStatus(taskID string) (*VideoResult, error) {
endpoint := c.BaseURL + "/videos/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result OpenAISoraResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "completed",
}
if result.Error.Message != "" {
videoResult.Error = result.Error.Message
}
// 优先使用video_url字段兼容video.url嵌套结构
if result.VideoURL != "" {
videoResult.VideoURL = result.VideoURL
} else if result.Video.URL != "" {
videoResult.VideoURL = result.Video.URL
}
return videoResult, nil
}

427
pkg/video/video_client.go Normal file
View File

@@ -0,0 +1,427 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type VideoClient interface {
GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error)
GetTaskStatus(taskID string) (*VideoResult, error)
}
type VideoResult struct {
TaskID string
Status string
VideoURL string
ThumbnailURL string
Duration int
Width int
Height int
Error string
Completed bool
}
type VideoOptions struct {
Model string
Duration int
FPS int
Resolution string
AspectRatio string
Style string
MotionLevel int
CameraMotion string
Seed int64
FirstFrameURL string
LastFrameURL string
ReferenceImageURLs []string
}
type VideoOption func(*VideoOptions)
func WithModel(model string) VideoOption {
return func(o *VideoOptions) {
o.Model = model
}
}
func WithDuration(duration int) VideoOption {
return func(o *VideoOptions) {
o.Duration = duration
}
}
func WithFPS(fps int) VideoOption {
return func(o *VideoOptions) {
o.FPS = fps
}
}
func WithResolution(resolution string) VideoOption {
return func(o *VideoOptions) {
o.Resolution = resolution
}
}
func WithAspectRatio(ratio string) VideoOption {
return func(o *VideoOptions) {
o.AspectRatio = ratio
}
}
func WithStyle(style string) VideoOption {
return func(o *VideoOptions) {
o.Style = style
}
}
func WithMotionLevel(level int) VideoOption {
return func(o *VideoOptions) {
o.MotionLevel = level
}
}
func WithCameraMotion(motion string) VideoOption {
return func(o *VideoOptions) {
o.CameraMotion = motion
}
}
func WithSeed(seed int64) VideoOption {
return func(o *VideoOptions) {
o.Seed = seed
}
}
func WithFirstFrame(url string) VideoOption {
return func(o *VideoOptions) {
o.FirstFrameURL = url
}
}
func WithLastFrame(url string) VideoOption {
return func(o *VideoOptions) {
o.LastFrameURL = url
}
}
func WithReferenceImages(urls []string) VideoOption {
return func(o *VideoOptions) {
o.ReferenceImageURLs = urls
}
}
type RunwayClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type RunwayRequest struct {
Model string `json:"model"`
PromptImage string `json:"prompt_image"`
PromptText string `json:"prompt_text"`
Duration int `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
type RunwayResponse struct {
ID string `json:"id"`
Status string `json:"status"`
Output struct {
URL string `json:"url"`
} `json:"output"`
Error string `json:"error,omitempty"`
}
func NewRunwayClient(baseURL, apiKey, model string) *RunwayClient {
return &RunwayClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 180 * time.Second,
},
}
}
func (c *RunwayClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 5,
AspectRatio: "16:9",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := RunwayRequest{
Model: model,
PromptImage: imageURL,
PromptText: prompt,
Duration: options.Duration,
AspectRatio: options.AspectRatio,
Seed: options.Seed,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/video/generate"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result RunwayResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error != "" {
return nil, fmt.Errorf("runway error: %s", result.Error)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "succeeded",
}
if result.Output.URL != "" {
videoResult.VideoURL = result.Output.URL
}
return videoResult, nil
}
func (c *RunwayClient) GetTaskStatus(taskID string) (*VideoResult, error) {
endpoint := c.BaseURL + "/v1/video/status/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result RunwayResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "succeeded",
}
if result.Error != "" {
videoResult.Error = result.Error
}
if result.Output.URL != "" {
videoResult.VideoURL = result.Output.URL
}
return videoResult, nil
}
type PikaClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type PikaRequest struct {
Model string `json:"model"`
Image string `json:"image"`
Prompt string `json:"prompt"`
Duration int `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
Motion int `json:"motion,omitempty"`
CameraMotion string `json:"camera_motion,omitempty"`
Seed int64 `json:"seed,omitempty"`
}
type PikaResponse struct {
JobID string `json:"job_id"`
Status string `json:"status"`
Result struct {
VideoURL string `json:"video_url"`
} `json:"result"`
Error string `json:"error,omitempty"`
}
func NewPikaClient(baseURL, apiKey, model string) *PikaClient {
return &PikaClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 180 * time.Second,
},
}
}
func (c *PikaClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 3,
AspectRatio: "16:9",
MotionLevel: 50,
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := PikaRequest{
Model: model,
Image: imageURL,
Prompt: prompt,
Duration: options.Duration,
AspectRatio: options.AspectRatio,
Motion: options.MotionLevel,
CameraMotion: options.CameraMotion,
Seed: options.Seed,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/video/generate"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result PikaResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error != "" {
return nil, fmt.Errorf("pika error: %s", result.Error)
}
videoResult := &VideoResult{
TaskID: result.JobID,
Status: result.Status,
Completed: result.Status == "completed",
}
if result.Result.VideoURL != "" {
videoResult.VideoURL = result.Result.VideoURL
}
return videoResult, nil
}
func (c *PikaClient) GetTaskStatus(taskID string) (*VideoResult, error) {
endpoint := c.BaseURL + "/v1/video/status/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result PikaResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
videoResult := &VideoResult{
TaskID: result.JobID,
Status: result.Status,
Completed: result.Status == "completed",
}
if result.Error != "" {
videoResult.Error = result.Error
}
if result.Result.VideoURL != "" {
videoResult.VideoURL = result.Result.VideoURL
}
return videoResult, nil
}

View File

@@ -0,0 +1,290 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// VolcesArkClient 火山引擎ARK视频生成客户端
type VolcesArkClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
QueryEndpoint string
HTTPClient *http.Client
}
type VolcesArkContent struct {
Type string `json:"type"`
Text string `json:"text,omitempty"`
ImageURL map[string]interface{} `json:"image_url,omitempty"`
Role string `json:"role,omitempty"`
}
type VolcesArkRequest struct {
Model string `json:"model"`
Content []VolcesArkContent `json:"content"`
GenerateAudio bool `json:"generate_audio,omitempty"`
}
type VolcesArkResponse struct {
ID string `json:"id"`
Model string `json:"model"`
Status string `json:"status"`
Content struct {
VideoURL string `json:"video_url"`
} `json:"content"`
Usage struct {
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
Seed int `json:"seed"`
Resolution string `json:"resolution"`
Ratio string `json:"ratio"`
Duration int `json:"duration"`
FramesPerSecond int `json:"framespersecond"`
ServiceTier string `json:"service_tier"`
ExecutionExpiresAfter int `json:"execution_expires_after"`
GenerateAudio bool `json:"generate_audio"`
Error interface{} `json:"error,omitempty"`
}
func NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *VolcesArkClient {
if endpoint == "" {
endpoint = "/api/v3/contents/generations/tasks"
}
if queryEndpoint == "" {
queryEndpoint = endpoint
}
return &VolcesArkClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
QueryEndpoint: queryEndpoint,
HTTPClient: &http.Client{
Timeout: 300 * time.Second,
},
}
}
// GenerateVideo 生成视频(支持首帧、首尾帧、参考图等多种模式)
func (c *VolcesArkClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 5,
AspectRatio: "adaptive",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
// 构建prompt文本包含duration和ratio参数
promptText := prompt
if options.AspectRatio != "" {
promptText += fmt.Sprintf(" --ratio %s", options.AspectRatio)
}
if options.Duration > 0 {
promptText += fmt.Sprintf(" --dur %d", options.Duration)
}
content := []VolcesArkContent{
{
Type: "text",
Text: promptText,
},
}
// 处理不同的图片模式
// 1. 组图模式多个reference_image
if len(options.ReferenceImageURLs) > 0 {
for _, refURL := range options.ReferenceImageURLs {
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": refURL,
},
Role: "reference_image",
})
}
} else if options.FirstFrameURL != "" && options.LastFrameURL != "" {
// 2. 首尾帧模式
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": options.FirstFrameURL,
},
Role: "first_frame",
})
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": options.LastFrameURL,
},
Role: "last_frame",
})
} else if imageURL != "" {
// 3. 单图模式(默认)
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": imageURL,
},
// 单图模式不需要role
})
} else if options.FirstFrameURL != "" {
// 4. 只有首帧
content = append(content, VolcesArkContent{
Type: "image_url",
ImageURL: map[string]interface{}{
"url": options.FirstFrameURL,
},
Role: "first_frame",
})
}
// 只有 seedance-1-5-pro 模型支持 generate_audio 参数
generateAudio := false
if strings.Contains(strings.ToLower(model), "seedance-1-5-pro") {
generateAudio = true
}
reqBody := VolcesArkRequest{
Model: model,
Content: content,
GenerateAudio: generateAudio,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + c.Endpoint
fmt.Printf("[VolcesARK] Generating video - Endpoint: %s, FullURL: %s, Model: %s\n", c.Endpoint, endpoint, model)
fmt.Printf("[VolcesARK] Request body: %s\n", string(jsonData))
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
fmt.Printf("[VolcesARK] Response status: %d, body: %s\n", resp.StatusCode, string(body))
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result VolcesArkResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
fmt.Printf("[VolcesARK] Video generation initiated - TaskID: %s, Status: %s\n", result.ID, result.Status)
if result.Error != nil {
errorMsg := fmt.Sprintf("%v", result.Error)
return nil, fmt.Errorf("volces error: %s", errorMsg)
}
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "completed" || result.Status == "succeeded",
Duration: result.Duration,
}
if result.Content.VideoURL != "" {
videoResult.VideoURL = result.Content.VideoURL
videoResult.Completed = true
}
return videoResult, nil
}
func (c *VolcesArkClient) GetTaskStatus(taskID string) (*VideoResult, error) {
// 替换占位符{taskId}、{task_id}或直接拼接
queryPath := c.QueryEndpoint
if strings.Contains(queryPath, "{taskId}") {
queryPath = strings.ReplaceAll(queryPath, "{taskId}", taskID)
} else if strings.Contains(queryPath, "{task_id}") {
queryPath = strings.ReplaceAll(queryPath, "{task_id}", taskID)
} else {
queryPath = queryPath + "/" + taskID
}
endpoint := c.BaseURL + queryPath
fmt.Printf("[VolcesARK] Querying task status - TaskID: %s, QueryEndpoint: %s, FullURL: %s\n", taskID, c.QueryEndpoint, endpoint)
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
fmt.Printf("[VolcesARK] Response body: %s\n", string(body))
var result VolcesArkResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
fmt.Printf("[VolcesARK] Parsed result - ID: %s, Status: %s, VideoURL: %s\n", result.ID, result.Status, result.Content.VideoURL)
videoResult := &VideoResult{
TaskID: result.ID,
Status: result.Status,
Completed: result.Status == "completed" || result.Status == "succeeded",
Duration: result.Duration,
}
if result.Error != nil {
videoResult.Error = fmt.Sprintf("%v", result.Error)
}
if result.Content.VideoURL != "" {
videoResult.VideoURL = result.Content.VideoURL
videoResult.Completed = true
}
return videoResult, nil
}