create project
This commit is contained in:
7
pkg/ai/client.go
Normal file
7
pkg/ai/client.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package ai
|
||||
|
||||
// AIClient 定义文本生成客户端接口
|
||||
type AIClient interface {
|
||||
GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error)
|
||||
TestConnection() error
|
||||
}
|
||||
195
pkg/ai/gemini_client.go
Normal file
195
pkg/ai/gemini_client.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type GeminiClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Endpoint string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type GeminiTextRequest struct {
|
||||
Contents []GeminiContent `json:"contents"`
|
||||
SystemInstruction *GeminiInstruction `json:"systemInstruction,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiContent struct {
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type GeminiInstruction struct {
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}
|
||||
|
||||
type GeminiTextResponse struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
Role string `json:"role"`
|
||||
} `json:"content"`
|
||||
FinishReason string `json:"finishReason"`
|
||||
Index int `json:"index"`
|
||||
SafetyRatings []struct {
|
||||
Category string `json:"category"`
|
||||
Probability string `json:"probability"`
|
||||
} `json:"safetyRatings"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
func NewGeminiClient(baseURL, apiKey, model, endpoint string) *GeminiClient {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
if endpoint == "" {
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
}
|
||||
if model == "" {
|
||||
model = "gemini-3-pro"
|
||||
}
|
||||
return &GeminiClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GeminiClient) GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) {
|
||||
model := c.Model
|
||||
|
||||
// 构建请求体
|
||||
reqBody := GeminiTextRequest{
|
||||
Contents: []GeminiContent{
|
||||
{
|
||||
Parts: []GeminiPart{{Text: prompt}},
|
||||
Role: "user",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 使用 systemInstruction 字段处理系统提示
|
||||
if systemPrompt != "" {
|
||||
reqBody.SystemInstruction = &GeminiInstruction{
|
||||
Parts: []GeminiPart{{Text: systemPrompt}},
|
||||
}
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini: Failed to marshal request: %v\n", err)
|
||||
return "", fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
// 替换端点中的 {model} 占位符
|
||||
endpoint := c.BaseURL + c.Endpoint
|
||||
endpoint = strings.ReplaceAll(endpoint, "{model}", model)
|
||||
url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey)
|
||||
|
||||
// 打印请求信息(隐藏 API Key)
|
||||
safeURL := strings.Replace(url, c.APIKey, "***", 1)
|
||||
fmt.Printf("Gemini: Sending request to: %s\n", safeURL)
|
||||
requestPreview := string(jsonData)
|
||||
if len(jsonData) > 300 {
|
||||
requestPreview = string(jsonData[:300]) + "..."
|
||||
}
|
||||
fmt.Printf("Gemini: Request body: %s\n", requestPreview)
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini: Failed to create request: %v\n", err)
|
||||
return "", fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
fmt.Printf("Gemini: Executing HTTP request...\n")
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini: HTTP request failed: %v\n", err)
|
||||
return "", fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
fmt.Printf("Gemini: Received response with status: %d\n", resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini: Failed to read response body: %v\n", err)
|
||||
return "", fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
fmt.Printf("Gemini: API error (status %d): %s\n", resp.StatusCode, string(body))
|
||||
return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 打印响应体用于调试
|
||||
bodyPreview := string(body)
|
||||
if len(body) > 500 {
|
||||
bodyPreview = string(body[:500]) + "..."
|
||||
}
|
||||
fmt.Printf("Gemini: Response body: %s\n", bodyPreview)
|
||||
|
||||
var result GeminiTextResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
errorPreview := string(body)
|
||||
if len(body) > 200 {
|
||||
errorPreview = string(body[:200])
|
||||
}
|
||||
fmt.Printf("Gemini: Failed to parse response: %v\n", err)
|
||||
return "", fmt.Errorf("parse response: %w, body preview: %s", err, errorPreview)
|
||||
}
|
||||
|
||||
fmt.Printf("Gemini: Successfully parsed response, candidates count: %d\n", len(result.Candidates))
|
||||
|
||||
if len(result.Candidates) == 0 {
|
||||
fmt.Printf("Gemini: No candidates in response\n")
|
||||
return "", fmt.Errorf("no candidates in response")
|
||||
}
|
||||
|
||||
if len(result.Candidates[0].Content.Parts) == 0 {
|
||||
fmt.Printf("Gemini: No parts in first candidate\n")
|
||||
return "", fmt.Errorf("no parts in response")
|
||||
}
|
||||
|
||||
responseText := result.Candidates[0].Content.Parts[0].Text
|
||||
fmt.Printf("Gemini: Generated text: %s\n", responseText)
|
||||
|
||||
return responseText, nil
|
||||
}
|
||||
|
||||
func (c *GeminiClient) TestConnection() error {
|
||||
fmt.Printf("Gemini: TestConnection called with BaseURL=%s, Model=%s, Endpoint=%s\n", c.BaseURL, c.Model, c.Endpoint)
|
||||
_, err := c.GenerateText("Hello", "")
|
||||
if err != nil {
|
||||
fmt.Printf("Gemini: TestConnection failed: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("Gemini: TestConnection succeeded\n")
|
||||
}
|
||||
return err
|
||||
}
|
||||
227
pkg/ai/openai_client.go
Normal file
227
pkg/ai/openai_client.go
Normal file
@@ -0,0 +1,227 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpenAIClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Endpoint string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []ChatMessage `json:"messages"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Index int `json:"index"`
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func NewOpenAIClient(baseURL, apiKey, model, endpoint string) *OpenAIClient {
|
||||
if endpoint == "" {
|
||||
endpoint = "/v1/chat/completions"
|
||||
}
|
||||
|
||||
return &OpenAIClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*ChatCompletionRequest)) (*ChatCompletionResponse, error) {
|
||||
req := &ChatCompletionRequest{
|
||||
Model: c.Model,
|
||||
Messages: messages,
|
||||
}
|
||||
|
||||
for _, option := range options {
|
||||
option(req)
|
||||
}
|
||||
|
||||
return c.sendChatRequest(req)
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
|
||||
jsonData, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
fmt.Printf("OpenAI: Failed to marshal request: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := c.BaseURL + c.Endpoint
|
||||
|
||||
// 打印请求信息
|
||||
fmt.Printf("OpenAI: Sending request to: %s\n", url)
|
||||
fmt.Printf("OpenAI: BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model)
|
||||
requestPreview := string(jsonData)
|
||||
if len(jsonData) > 300 {
|
||||
requestPreview = string(jsonData[:300]) + "..."
|
||||
}
|
||||
fmt.Printf("OpenAI: Request body: %s\n", requestPreview)
|
||||
|
||||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
fmt.Printf("OpenAI: Failed to create request: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
fmt.Printf("OpenAI: Executing HTTP request...\n")
|
||||
resp, err := c.HTTPClient.Do(httpReq)
|
||||
if err != nil {
|
||||
fmt.Printf("OpenAI: HTTP request failed: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
fmt.Printf("OpenAI: Received response with status: %d\n", resp.StatusCode)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
fmt.Printf("OpenAI: Failed to read response body: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
fmt.Printf("OpenAI: API error (status %d): %s\n", resp.StatusCode, string(body))
|
||||
var errResp ErrorResponse
|
||||
if err := json.Unmarshal(body, &errResp); err != nil {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
return nil, fmt.Errorf("API error: %s", errResp.Error.Message)
|
||||
}
|
||||
|
||||
// 打印响应体用于调试
|
||||
bodyPreview := string(body)
|
||||
if len(body) > 500 {
|
||||
bodyPreview = string(body[:500]) + "..."
|
||||
}
|
||||
fmt.Printf("OpenAI: Response body: %s\n", bodyPreview)
|
||||
|
||||
var chatResp ChatCompletionResponse
|
||||
if err := json.Unmarshal(body, &chatResp); err != nil {
|
||||
errorPreview := string(body)
|
||||
if len(body) > 200 {
|
||||
errorPreview = string(body[:200])
|
||||
}
|
||||
fmt.Printf("OpenAI: Failed to parse response: %v\n", err)
|
||||
return nil, fmt.Errorf("failed to unmarshal response: %w, body preview: %s", err, errorPreview)
|
||||
}
|
||||
|
||||
fmt.Printf("OpenAI: Successfully parsed response, choices count: %d\n", len(chatResp.Choices))
|
||||
|
||||
return &chatResp, nil
|
||||
}
|
||||
|
||||
func WithTemperature(temp float64) func(*ChatCompletionRequest) {
|
||||
return func(req *ChatCompletionRequest) {
|
||||
req.Temperature = temp
|
||||
}
|
||||
}
|
||||
|
||||
func WithMaxTokens(tokens int) func(*ChatCompletionRequest) {
|
||||
return func(req *ChatCompletionRequest) {
|
||||
req.MaxTokens = tokens
|
||||
}
|
||||
}
|
||||
|
||||
func WithTopP(topP float64) func(*ChatCompletionRequest) {
|
||||
return func(req *ChatCompletionRequest) {
|
||||
req.TopP = topP
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) {
|
||||
messages := []ChatMessage{}
|
||||
|
||||
if systemPrompt != "" {
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "system",
|
||||
Content: systemPrompt,
|
||||
})
|
||||
}
|
||||
|
||||
messages = append(messages, ChatMessage{
|
||||
Role: "user",
|
||||
Content: prompt,
|
||||
})
|
||||
|
||||
resp, err := c.ChatCompletion(messages, options...)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return "", fmt.Errorf("no response from API")
|
||||
}
|
||||
|
||||
return resp.Choices[0].Message.Content, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIClient) TestConnection() error {
|
||||
fmt.Printf("OpenAI: TestConnection called with BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model)
|
||||
|
||||
messages := []ChatMessage{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
}
|
||||
|
||||
_, err := c.ChatCompletion(messages, WithMaxTokens(10))
|
||||
if err != nil {
|
||||
fmt.Printf("OpenAI: TestConnection failed: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("OpenAI: TestConnection succeeded\n")
|
||||
}
|
||||
return err
|
||||
}
|
||||
89
pkg/config/config.go
Normal file
89
pkg/config/config.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
App AppConfig `mapstructure:"app"`
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Storage StorageConfig `mapstructure:"storage"`
|
||||
AI AIConfig `mapstructure:"ai"`
|
||||
}
|
||||
|
||||
type AppConfig struct {
|
||||
Name string `mapstructure:"name"`
|
||||
Version string `mapstructure:"version"`
|
||||
Debug bool `mapstructure:"debug"`
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
Port int `mapstructure:"port"`
|
||||
Host string `mapstructure:"host"`
|
||||
CORSOrigins []string `mapstructure:"cors_origins"`
|
||||
ReadTimeout int `mapstructure:"read_timeout"`
|
||||
WriteTimeout int `mapstructure:"write_timeout"`
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Type string `mapstructure:"type"` // sqlite, mysql
|
||||
Path string `mapstructure:"path"` // SQLite数据库文件路径
|
||||
Host string `mapstructure:"host"`
|
||||
Port int `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
Database string `mapstructure:"database"`
|
||||
Charset string `mapstructure:"charset"`
|
||||
MaxIdle int `mapstructure:"max_idle"`
|
||||
MaxOpen int `mapstructure:"max_open"`
|
||||
}
|
||||
|
||||
type StorageConfig struct {
|
||||
Type string `mapstructure:"type"` // local, minio
|
||||
LocalPath string `mapstructure:"local_path"` // 本地存储路径
|
||||
BaseURL string `mapstructure:"base_url"` // 访问URL前缀
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
DefaultTextProvider string `mapstructure:"default_text_provider"`
|
||||
DefaultImageProvider string `mapstructure:"default_image_provider"`
|
||||
DefaultVideoProvider string `mapstructure:"default_video_provider"`
|
||||
}
|
||||
|
||||
func LoadConfig() (*Config, error) {
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath("./configs")
|
||||
viper.AddConfigPath(".")
|
||||
|
||||
viper.AutomaticEnv()
|
||||
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
return nil, fmt.Errorf("failed to read config: %w", err)
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := viper.Unmarshal(&config); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (c *DatabaseConfig) DSN() string {
|
||||
if c.Type == "sqlite" {
|
||||
return c.Path
|
||||
}
|
||||
// MySQL DSN
|
||||
return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=%s&parseTime=True&loc=Local",
|
||||
c.User,
|
||||
c.Password,
|
||||
c.Host,
|
||||
c.Port,
|
||||
c.Database,
|
||||
c.Charset,
|
||||
)
|
||||
}
|
||||
277
pkg/image/gemini_image_client.go
Normal file
277
pkg/image/gemini_image_client.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type GeminiImageClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Endpoint string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type GeminiImageRequest struct {
|
||||
Contents []struct {
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
} `json:"contents"`
|
||||
GenerationConfig struct {
|
||||
ResponseModalities []string `json:"responseModalities"`
|
||||
} `json:"generationConfig"`
|
||||
}
|
||||
|
||||
type GeminiPart struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiInlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"` // base64 编码的图片数据
|
||||
}
|
||||
|
||||
type GeminiImageResponse struct {
|
||||
Candidates []struct {
|
||||
Content struct {
|
||||
Parts []struct {
|
||||
InlineData struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Data string `json:"data"`
|
||||
} `json:"inlineData,omitempty"`
|
||||
Text string `json:"text,omitempty"`
|
||||
} `json:"parts"`
|
||||
} `json:"content"`
|
||||
} `json:"candidates"`
|
||||
UsageMetadata struct {
|
||||
PromptTokenCount int `json:"promptTokenCount"`
|
||||
CandidatesTokenCount int `json:"candidatesTokenCount"`
|
||||
TotalTokenCount int `json:"totalTokenCount"`
|
||||
} `json:"usageMetadata"`
|
||||
}
|
||||
|
||||
// downloadImageToBase64 下载图片 URL 并转换为 base64
|
||||
func downloadImageToBase64(imageURL string) (string, string, error) {
|
||||
resp, err := http.Get(imageURL)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("download image: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", "", fmt.Errorf("download image failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
imageData, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", "", fmt.Errorf("read image data: %w", err)
|
||||
}
|
||||
|
||||
// 根据 Content-Type 确定 mimeType
|
||||
mimeType := resp.Header.Get("Content-Type")
|
||||
if mimeType == "" {
|
||||
mimeType = "image/jpeg"
|
||||
}
|
||||
|
||||
base64Data := base64.StdEncoding.EncodeToString(imageData)
|
||||
return base64Data, mimeType, nil
|
||||
}
|
||||
|
||||
func NewGeminiImageClient(baseURL, apiKey, model, endpoint string) *GeminiImageClient {
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
if endpoint == "" {
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
}
|
||||
if model == "" {
|
||||
model = "gemini-3-pro-image-preview"
|
||||
}
|
||||
return &GeminiImageClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *GeminiImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
||||
options := &ImageOptions{
|
||||
Size: "1024x1024",
|
||||
Quality: "standard",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
promptText := prompt
|
||||
if options.NegativePrompt != "" {
|
||||
promptText += fmt.Sprintf("\n\nNegative prompt: %s", options.NegativePrompt)
|
||||
}
|
||||
if options.Size != "" {
|
||||
promptText += fmt.Sprintf("\n\nImage size: %s", options.Size)
|
||||
}
|
||||
|
||||
// 构建请求的 parts,支持参考图
|
||||
parts := []GeminiPart{}
|
||||
|
||||
// 如果有参考图,先添加参考图
|
||||
if len(options.ReferenceImages) > 0 {
|
||||
for _, refImg := range options.ReferenceImages {
|
||||
var base64Data string
|
||||
var mimeType string
|
||||
var err error
|
||||
|
||||
// 检查是否是 HTTP/HTTPS URL
|
||||
if strings.HasPrefix(refImg, "http://") || strings.HasPrefix(refImg, "https://") {
|
||||
// 下载图片并转换为 base64
|
||||
base64Data, mimeType, err = downloadImageToBase64(refImg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
} else if strings.HasPrefix(refImg, "data:") {
|
||||
// 如果是 data URI 格式,需要解析
|
||||
// 格式: data:image/jpeg;base64,xxxxx
|
||||
mimeType = "image/jpeg"
|
||||
parts := []byte(refImg)
|
||||
for i := 0; i < len(parts); i++ {
|
||||
if parts[i] == ',' {
|
||||
base64Data = refImg[i+1:]
|
||||
// 提取 mime type
|
||||
if i > 11 {
|
||||
mimeTypeEnd := i
|
||||
for j := 5; j < i; j++ {
|
||||
if parts[j] == ';' {
|
||||
mimeTypeEnd = j
|
||||
break
|
||||
}
|
||||
}
|
||||
mimeType = refImg[5:mimeTypeEnd]
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 假设已经是 base64 编码
|
||||
base64Data = refImg
|
||||
mimeType = "image/jpeg"
|
||||
}
|
||||
|
||||
if base64Data != "" {
|
||||
parts = append(parts, GeminiPart{
|
||||
InlineData: &GeminiInlineData{
|
||||
MimeType: mimeType,
|
||||
Data: base64Data,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 添加文本提示词
|
||||
parts = append(parts, GeminiPart{
|
||||
Text: promptText,
|
||||
})
|
||||
|
||||
reqBody := GeminiImageRequest{
|
||||
Contents: []struct {
|
||||
Parts []GeminiPart `json:"parts"`
|
||||
}{
|
||||
{
|
||||
Parts: parts,
|
||||
},
|
||||
},
|
||||
GenerationConfig: struct {
|
||||
ResponseModalities []string `json:"responseModalities"`
|
||||
}{
|
||||
ResponseModalities: []string{"IMAGE"},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + c.Endpoint
|
||||
endpoint = replaceModelPlaceholder(endpoint, model)
|
||||
url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey)
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyStr := string(body)
|
||||
if len(bodyStr) > 1000 {
|
||||
bodyStr = fmt.Sprintf("%s ... %s", bodyStr[:500], bodyStr[len(bodyStr)-500:])
|
||||
}
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, bodyStr)
|
||||
}
|
||||
|
||||
var result GeminiImageResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
|
||||
return nil, fmt.Errorf("no image generated in response")
|
||||
}
|
||||
|
||||
base64Data := result.Candidates[0].Content.Parts[0].InlineData.Data
|
||||
if base64Data == "" {
|
||||
return nil, fmt.Errorf("no base64 image data in response")
|
||||
}
|
||||
|
||||
dataURI := fmt.Sprintf("data:image/jpeg;base64,%s", base64Data)
|
||||
|
||||
return &ImageResult{
|
||||
Status: "completed",
|
||||
ImageURL: dataURI,
|
||||
Completed: true,
|
||||
Width: 1024,
|
||||
Height: 1024,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *GeminiImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
||||
return nil, fmt.Errorf("not supported for Gemini (synchronous generation)")
|
||||
}
|
||||
|
||||
func replaceModelPlaceholder(endpoint, model string) string {
|
||||
result := endpoint
|
||||
if bytes.Contains([]byte(result), []byte("{model}")) {
|
||||
result = string(bytes.ReplaceAll([]byte(result), []byte("{model}"), []byte(model)))
|
||||
}
|
||||
return result
|
||||
}
|
||||
93
pkg/image/image_client.go
Normal file
93
pkg/image/image_client.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package image
|
||||
|
||||
type ImageClient interface {
|
||||
GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error)
|
||||
GetTaskStatus(taskID string) (*ImageResult, error)
|
||||
}
|
||||
|
||||
type ImageResult struct {
|
||||
TaskID string
|
||||
Status string
|
||||
ImageURL string
|
||||
Width int
|
||||
Height int
|
||||
Error string
|
||||
Completed bool
|
||||
}
|
||||
|
||||
type ImageOptions struct {
|
||||
NegativePrompt string
|
||||
Size string
|
||||
Quality string
|
||||
Style string
|
||||
Steps int
|
||||
CfgScale float64
|
||||
Seed int64
|
||||
Model string
|
||||
Width int
|
||||
Height int
|
||||
ReferenceImages []string // 参考图片URL列表
|
||||
}
|
||||
|
||||
type ImageOption func(*ImageOptions)
|
||||
|
||||
func WithNegativePrompt(prompt string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.NegativePrompt = prompt
|
||||
}
|
||||
}
|
||||
|
||||
func WithSize(size string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Size = size
|
||||
}
|
||||
}
|
||||
|
||||
func WithQuality(quality string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Quality = quality
|
||||
}
|
||||
}
|
||||
|
||||
func WithStyle(style string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Style = style
|
||||
}
|
||||
}
|
||||
|
||||
func WithSteps(steps int) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Steps = steps
|
||||
}
|
||||
}
|
||||
|
||||
func WithCfgScale(scale float64) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.CfgScale = scale
|
||||
}
|
||||
}
|
||||
|
||||
func WithSeed(seed int64) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Seed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithModel(model string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Model = model
|
||||
}
|
||||
}
|
||||
|
||||
func WithDimensions(width, height int) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.Width = width
|
||||
o.Height = height
|
||||
}
|
||||
}
|
||||
|
||||
func WithReferenceImages(images []string) ImageOption {
|
||||
return func(o *ImageOptions) {
|
||||
o.ReferenceImages = images
|
||||
}
|
||||
}
|
||||
128
pkg/image/openai_image_client.go
Normal file
128
pkg/image/openai_image_client.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpenAIImageClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Endpoint string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type DALLERequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
N int `json:"n"`
|
||||
Image []string `json:"image,omitempty"`
|
||||
}
|
||||
|
||||
type DALLEResponse struct {
|
||||
Created int64 `json:"created"`
|
||||
Data []struct {
|
||||
URL string `json:"url"`
|
||||
RevisedPrompt string `json:"revised_prompt,omitempty"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
func NewOpenAIImageClient(baseURL, apiKey, model, endpoint string) *OpenAIImageClient {
|
||||
if endpoint == "" {
|
||||
endpoint = "/v1/images/generations"
|
||||
}
|
||||
return &OpenAIImageClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
||||
options := &ImageOptions{
|
||||
Size: "1920x1920",
|
||||
Quality: "standard",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
reqBody := DALLERequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Size: options.Size,
|
||||
Quality: options.Quality,
|
||||
N: 1,
|
||||
Image: options.ReferenceImages,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := c.BaseURL + c.Endpoint
|
||||
fmt.Printf("[OpenAI Image] Request URL: %s\n", url)
|
||||
fmt.Printf("[OpenAI Image] Request Body: %s\n", string(jsonData))
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
fmt.Printf("OpenAI API Response: %s\n", string(body))
|
||||
|
||||
var result DALLEResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
|
||||
}
|
||||
|
||||
if len(result.Data) == 0 {
|
||||
return nil, fmt.Errorf("no image generated, response: %s", string(body))
|
||||
}
|
||||
|
||||
return &ImageResult{
|
||||
Status: "completed",
|
||||
ImageURL: result.Data[0].URL,
|
||||
Completed: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
||||
return nil, fmt.Errorf("not supported for OpenAI/DALL-E")
|
||||
}
|
||||
158
pkg/image/volcengine_image_client.go
Normal file
158
pkg/image/volcengine_image_client.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package image
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type VolcEngineImageClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Endpoint string
|
||||
QueryEndpoint string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type VolcEngineImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Image []string `json:"image,omitempty"`
|
||||
SequentialImageGeneration string `json:"sequential_image_generation,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Watermark bool `json:"watermark,omitempty"`
|
||||
}
|
||||
|
||||
type VolcEngineImageResponse struct {
|
||||
Model string `json:"model"`
|
||||
Created int64 `json:"created"`
|
||||
Data []struct {
|
||||
URL string `json:"url"`
|
||||
Size string `json:"size"`
|
||||
} `json:"data"`
|
||||
Usage struct {
|
||||
GeneratedImages int `json:"generated_images"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
Error interface{} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func NewVolcEngineImageClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *VolcEngineImageClient {
|
||||
if endpoint == "" {
|
||||
endpoint = "/api/v3/images/generations"
|
||||
}
|
||||
if queryEndpoint == "" {
|
||||
queryEndpoint = endpoint
|
||||
}
|
||||
return &VolcEngineImageClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
QueryEndpoint: queryEndpoint,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 10 * time.Minute,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *VolcEngineImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
|
||||
options := &ImageOptions{
|
||||
Size: "1024x1024",
|
||||
Quality: "standard",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
promptText := prompt
|
||||
if options.NegativePrompt != "" {
|
||||
promptText += fmt.Sprintf(". Negative: %s", options.NegativePrompt)
|
||||
}
|
||||
|
||||
size := options.Size
|
||||
if size == "" {
|
||||
if model == "doubao-seedream-4-5-251128" {
|
||||
size = "2K"
|
||||
} else {
|
||||
size = "1K"
|
||||
}
|
||||
}
|
||||
|
||||
reqBody := VolcEngineImageRequest{
|
||||
Model: model,
|
||||
Prompt: promptText,
|
||||
Image: options.ReferenceImages,
|
||||
SequentialImageGeneration: "disabled",
|
||||
Size: size,
|
||||
Watermark: false,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := c.BaseURL + c.Endpoint
|
||||
fmt.Printf("[VolcEngine Image] Request URL: %s\n", url)
|
||||
fmt.Printf("[VolcEngine Image] Request Body: %s\n", string(jsonData))
|
||||
|
||||
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("VolcEngine Image API Response: %s\n", string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result VolcEngineImageResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("volcengine error: %v", result.Error)
|
||||
}
|
||||
|
||||
if len(result.Data) == 0 {
|
||||
return nil, fmt.Errorf("no image generated")
|
||||
}
|
||||
|
||||
return &ImageResult{
|
||||
Status: "completed",
|
||||
ImageURL: result.Data[0].URL,
|
||||
Completed: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *VolcEngineImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
|
||||
return nil, fmt.Errorf("not supported for VolcEngine Seedream (synchronous generation)")
|
||||
}
|
||||
35
pkg/logger/logger.go
Normal file
35
pkg/logger/logger.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
*zap.SugaredLogger
|
||||
}
|
||||
|
||||
func NewLogger(debug bool) *Logger {
|
||||
var config zap.Config
|
||||
|
||||
if debug {
|
||||
config = zap.NewDevelopmentConfig()
|
||||
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||
// 在开发模式下,禁用时间戳和调用者信息,使输出更简洁
|
||||
config.EncoderConfig.TimeKey = ""
|
||||
config.EncoderConfig.CallerKey = ""
|
||||
} else {
|
||||
config = zap.NewProductionConfig()
|
||||
config.EncoderConfig.TimeKey = "timestamp"
|
||||
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
}
|
||||
|
||||
logger, err := config.Build()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &Logger{
|
||||
SugaredLogger: logger.Sugar(),
|
||||
}
|
||||
}
|
||||
119
pkg/response/response.go
Normal file
119
pkg/response/response.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Response struct {
|
||||
Success bool `json:"success"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Error *ErrorInfo `json:"error,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Timestamp string `json:"timestamp"`
|
||||
}
|
||||
|
||||
type ErrorInfo struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details interface{} `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
type PaginationData struct {
|
||||
Items interface{} `json:"items"`
|
||||
Pagination Pagination `json:"pagination"`
|
||||
}
|
||||
|
||||
type Pagination struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Total int64 `json:"total"`
|
||||
TotalPages int64 `json:"total_pages"`
|
||||
}
|
||||
|
||||
func Success(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func SuccessWithMessage(c *gin.Context, message string, data interface{}) {
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Message: message,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func Created(c *gin.Context, data interface{}) {
|
||||
c.JSON(http.StatusCreated, Response{
|
||||
Success: true,
|
||||
Data: data,
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func SuccessWithPagination(c *gin.Context, items interface{}, total int64, page int, pageSize int) {
|
||||
totalPages := (total + int64(pageSize) - 1) / int64(pageSize)
|
||||
c.JSON(http.StatusOK, Response{
|
||||
Success: true,
|
||||
Data: PaginationData{
|
||||
Items: items,
|
||||
Pagination: Pagination{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Total: total,
|
||||
TotalPages: totalPages,
|
||||
},
|
||||
},
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func Error(c *gin.Context, statusCode int, errCode string, message string) {
|
||||
c.JSON(statusCode, Response{
|
||||
Success: false,
|
||||
Error: &ErrorInfo{
|
||||
Code: errCode,
|
||||
Message: message,
|
||||
},
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func ErrorWithDetails(c *gin.Context, statusCode int, errCode string, message string, details interface{}) {
|
||||
c.JSON(statusCode, Response{
|
||||
Success: false,
|
||||
Error: &ErrorInfo{
|
||||
Code: errCode,
|
||||
Message: message,
|
||||
Details: details,
|
||||
},
|
||||
Timestamp: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func BadRequest(c *gin.Context, message string) {
|
||||
Error(c, http.StatusBadRequest, "BAD_REQUEST", message)
|
||||
}
|
||||
|
||||
func Unauthorized(c *gin.Context, message string) {
|
||||
Error(c, http.StatusUnauthorized, "UNAUTHORIZED", message)
|
||||
}
|
||||
|
||||
func Forbidden(c *gin.Context, message string) {
|
||||
Error(c, http.StatusForbidden, "FORBIDDEN", message)
|
||||
}
|
||||
|
||||
func NotFound(c *gin.Context, message string) {
|
||||
Error(c, http.StatusNotFound, "NOT_FOUND", message)
|
||||
}
|
||||
|
||||
func InternalError(c *gin.Context, message string) {
|
||||
Error(c, http.StatusInternalServerError, "INTERNAL_ERROR", message)
|
||||
}
|
||||
153
pkg/utils/json_parser.go
Normal file
153
pkg/utils/json_parser.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SafeParseAIJSON 安全地解析AI返回的JSON,处理常见的格式问题
|
||||
// 包括:
|
||||
// 1. 移除Markdown代码块标记
|
||||
// 2. 提取JSON对象
|
||||
// 3. 清理多余的空白和换行
|
||||
// 4. 尝试修复截断的JSON
|
||||
// 5. 提供详细的错误信息
|
||||
func SafeParseAIJSON(aiResponse string, v interface{}) error {
|
||||
if aiResponse == "" {
|
||||
return fmt.Errorf("AI返回内容为空")
|
||||
}
|
||||
|
||||
// 1. 移除可能的Markdown代码块标记
|
||||
cleaned := strings.TrimSpace(aiResponse)
|
||||
cleaned = regexp.MustCompile("(?m)^```json\\s*").ReplaceAllString(cleaned, "")
|
||||
cleaned = regexp.MustCompile("(?m)^```\\s*").ReplaceAllString(cleaned, "")
|
||||
cleaned = strings.TrimSpace(cleaned)
|
||||
|
||||
// 2. 提取JSON对象 (查找第一个 { 到最后一个 })
|
||||
jsonRegex := regexp.MustCompile(`(?s)\{.*\}`)
|
||||
jsonMatch := jsonRegex.FindString(cleaned)
|
||||
|
||||
if jsonMatch == "" {
|
||||
return fmt.Errorf("响应中未找到有效的JSON对象,原始响应: %s", truncateString(aiResponse, 200))
|
||||
}
|
||||
|
||||
// 3. 尝试解析JSON
|
||||
err := json.Unmarshal([]byte(jsonMatch), v)
|
||||
if err == nil {
|
||||
return nil // 解析成功
|
||||
}
|
||||
|
||||
// 4. 如果解析失败,尝试修复截断的JSON
|
||||
fixedJSON := attemptJSONRepair(jsonMatch)
|
||||
if fixedJSON != jsonMatch {
|
||||
if err := json.Unmarshal([]byte(fixedJSON), v); err == nil {
|
||||
return nil // 修复后解析成功
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 提供详细的错误上下文
|
||||
if jsonErr, ok := err.(*json.SyntaxError); ok {
|
||||
errorPos := int(jsonErr.Offset)
|
||||
start := maxInt(0, errorPos-100)
|
||||
end := minInt(len(jsonMatch), errorPos+100)
|
||||
|
||||
context := jsonMatch[start:end]
|
||||
marker := strings.Repeat(" ", errorPos-start) + "^"
|
||||
|
||||
return fmt.Errorf(
|
||||
"JSON解析失败: %s\n错误位置附近:\n%s\n%s",
|
||||
jsonErr.Error(),
|
||||
context,
|
||||
marker,
|
||||
)
|
||||
}
|
||||
|
||||
return fmt.Errorf("JSON解析失败: %w\n原始响应: %s", err, truncateString(jsonMatch, 300))
|
||||
}
|
||||
|
||||
// attemptJSONRepair 尝试修复常见的JSON问题
|
||||
func attemptJSONRepair(jsonStr string) string {
|
||||
// 1. 处理未闭合的字符串
|
||||
// 如果最后一个字符不是 },尝试补全
|
||||
trimmed := strings.TrimSpace(jsonStr)
|
||||
|
||||
// 2. 检查是否有未闭合的引号
|
||||
if strings.Count(trimmed, `"`)%2 != 0 {
|
||||
// 有奇数个引号,尝试补全最后一个引号
|
||||
trimmed += `"`
|
||||
}
|
||||
|
||||
// 3. 统计括号
|
||||
openBraces := strings.Count(trimmed, "{")
|
||||
closeBraces := strings.Count(trimmed, "}")
|
||||
openBrackets := strings.Count(trimmed, "[")
|
||||
closeBrackets := strings.Count(trimmed, "]")
|
||||
|
||||
// 4. 补全未闭合的数组
|
||||
for i := 0; i < openBrackets-closeBrackets; i++ {
|
||||
trimmed += "]"
|
||||
}
|
||||
|
||||
// 5. 补全未闭合的对象
|
||||
for i := 0; i < openBraces-closeBraces; i++ {
|
||||
trimmed += "}"
|
||||
}
|
||||
|
||||
return trimmed
|
||||
}
|
||||
|
||||
// ExtractJSONFromText 从文本中提取JSON对象或数组
|
||||
func ExtractJSONFromText(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
|
||||
// 移除Markdown代码块
|
||||
text = regexp.MustCompile("(?m)^```json\\s*").ReplaceAllString(text, "")
|
||||
text = regexp.MustCompile("(?m)^```\\s*").ReplaceAllString(text, "")
|
||||
text = strings.TrimSpace(text)
|
||||
|
||||
// 查找JSON对象
|
||||
if idx := strings.Index(text, "{"); idx != -1 {
|
||||
if lastIdx := strings.LastIndex(text, "}"); lastIdx != -1 && lastIdx > idx {
|
||||
return text[idx : lastIdx+1]
|
||||
}
|
||||
}
|
||||
|
||||
// 查找JSON数组
|
||||
if idx := strings.Index(text, "["); idx != -1 {
|
||||
if lastIdx := strings.LastIndex(text, "]"); lastIdx != -1 && lastIdx > idx {
|
||||
return text[idx : lastIdx+1]
|
||||
}
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
|
||||
// ValidateJSON 验证JSON字符串是否有效
|
||||
func ValidateJSON(jsonStr string) error {
|
||||
var js json.RawMessage
|
||||
return json.Unmarshal([]byte(jsonStr), &js)
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func truncateString(s string, maxLen int) string {
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "..."
|
||||
}
|
||||
|
||||
func maxInt(a, b int) int {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func minInt(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
419
pkg/video/chatfire_client.go
Normal file
419
pkg/video/chatfire_client.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChatfireClient Chatfire 视频生成客户端
|
||||
type ChatfireClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Endpoint string
|
||||
QueryEndpoint string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type ChatfireRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
ImageURL string `json:"image_url,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
}
|
||||
|
||||
// ChatfireSoraRequest Sora 模型请求格式
|
||||
type ChatfireSoraRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Seconds string `json:"seconds,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
InputReference string `json:"input_reference,omitempty"`
|
||||
}
|
||||
|
||||
// ChatfireDoubaoRequest 豆包/火山模型请求格式
|
||||
type ChatfireDoubaoRequest struct {
|
||||
Model string `json:"model"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL map[string]interface{} `json:"image_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
} `json:"content"`
|
||||
}
|
||||
|
||||
type ChatfireResponse struct {
|
||||
ID string `json:"id"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
Error json.RawMessage `json:"error,omitempty"`
|
||||
Data struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
VideoURL string `json:"video_url,omitempty"`
|
||||
} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
type ChatfireTaskResponse struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
TaskID string `json:"task_id,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
VideoURL string `json:"video_url,omitempty"`
|
||||
Error json.RawMessage `json:"error,omitempty"`
|
||||
Data struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
VideoURL string `json:"video_url,omitempty"`
|
||||
} `json:"data,omitempty"`
|
||||
Content struct {
|
||||
VideoURL string `json:"video_url,omitempty"`
|
||||
} `json:"content,omitempty"`
|
||||
}
|
||||
|
||||
// getErrorMessage 从 error 字段提取错误信息(支持字符串或对象)
|
||||
func getErrorMessage(errorData json.RawMessage) string {
|
||||
if len(errorData) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// 尝试解析为字符串
|
||||
var errStr string
|
||||
if err := json.Unmarshal(errorData, &errStr); err == nil {
|
||||
return errStr
|
||||
}
|
||||
|
||||
// 尝试解析为对象
|
||||
var errObj struct {
|
||||
Message string `json:"message"`
|
||||
Code string `json:"code"`
|
||||
}
|
||||
if err := json.Unmarshal(errorData, &errObj); err == nil {
|
||||
if errObj.Message != "" {
|
||||
return errObj.Message
|
||||
}
|
||||
}
|
||||
|
||||
// 返回原始 JSON 字符串
|
||||
return string(errorData)
|
||||
}
|
||||
|
||||
func NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *ChatfireClient {
|
||||
if endpoint == "" {
|
||||
endpoint = "/video/generations"
|
||||
}
|
||||
if queryEndpoint == "" {
|
||||
queryEndpoint = "/video/task/{taskId}"
|
||||
}
|
||||
return &ChatfireClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
QueryEndpoint: queryEndpoint,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatfireClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 5,
|
||||
AspectRatio: "16:9",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
// 根据模型名称选择请求格式
|
||||
var jsonData []byte
|
||||
var err error
|
||||
|
||||
if strings.Contains(model, "doubao") || strings.Contains(model, "seedance") {
|
||||
// 豆包/火山格式
|
||||
reqBody := ChatfireDoubaoRequest{
|
||||
Model: model,
|
||||
}
|
||||
|
||||
// 构建prompt文本(包含duration和ratio参数)
|
||||
promptText := prompt
|
||||
if options.AspectRatio != "" {
|
||||
promptText += fmt.Sprintf(" --ratio %s", options.AspectRatio)
|
||||
}
|
||||
if options.Duration > 0 {
|
||||
promptText += fmt.Sprintf(" --dur %d", options.Duration)
|
||||
}
|
||||
|
||||
// 添加文本内容
|
||||
reqBody.Content = append(reqBody.Content, struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL map[string]interface{} `json:"image_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}{Type: "text", Text: promptText})
|
||||
|
||||
// 处理不同的图片模式
|
||||
// 1. 组图模式(多个reference_image)
|
||||
if len(options.ReferenceImageURLs) > 0 {
|
||||
for _, refURL := range options.ReferenceImageURLs {
|
||||
reqBody.Content = append(reqBody.Content, struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL map[string]interface{} `json:"image_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": refURL,
|
||||
},
|
||||
Role: "reference_image",
|
||||
})
|
||||
}
|
||||
} else if options.FirstFrameURL != "" && options.LastFrameURL != "" {
|
||||
// 2. 首尾帧模式
|
||||
reqBody.Content = append(reqBody.Content, struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL map[string]interface{} `json:"image_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": options.FirstFrameURL,
|
||||
},
|
||||
Role: "first_frame",
|
||||
})
|
||||
reqBody.Content = append(reqBody.Content, struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL map[string]interface{} `json:"image_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": options.LastFrameURL,
|
||||
},
|
||||
Role: "last_frame",
|
||||
})
|
||||
} else if imageURL != "" {
|
||||
// 3. 单图模式(默认)
|
||||
reqBody.Content = append(reqBody.Content, struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL map[string]interface{} `json:"image_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": imageURL,
|
||||
},
|
||||
// 单图模式不需要role
|
||||
})
|
||||
} else if options.FirstFrameURL != "" {
|
||||
// 4. 只有首帧
|
||||
reqBody.Content = append(reqBody.Content, struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL map[string]interface{} `json:"image_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": options.FirstFrameURL,
|
||||
},
|
||||
Role: "first_frame",
|
||||
})
|
||||
}
|
||||
|
||||
jsonData, err = json.Marshal(reqBody)
|
||||
} else if strings.Contains(model, "sora") {
|
||||
// Sora 格式
|
||||
seconds := fmt.Sprintf("%d", options.Duration)
|
||||
size := options.AspectRatio
|
||||
if size == "16:9" {
|
||||
size = "1280x720"
|
||||
} else if size == "9:16" {
|
||||
size = "720x1280"
|
||||
}
|
||||
|
||||
reqBody := ChatfireSoraRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
Seconds: seconds,
|
||||
Size: size,
|
||||
InputReference: imageURL,
|
||||
}
|
||||
jsonData, err = json.Marshal(reqBody)
|
||||
} else {
|
||||
// 默认格式
|
||||
reqBody := ChatfireRequest{
|
||||
Model: model,
|
||||
Prompt: prompt,
|
||||
ImageURL: imageURL,
|
||||
Duration: options.Duration,
|
||||
Size: options.AspectRatio,
|
||||
}
|
||||
jsonData, err = json.Marshal(reqBody)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + c.Endpoint
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
// 调试日志:打印响应内容
|
||||
fmt.Printf("[Chatfire] Response body: %s\n", string(body))
|
||||
|
||||
var result ChatfireResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
|
||||
}
|
||||
|
||||
// 优先使用 id 字段,其次使用 task_id
|
||||
taskID := result.ID
|
||||
if taskID == "" {
|
||||
taskID = result.TaskID
|
||||
}
|
||||
|
||||
// 如果有 data 嵌套,优先使用 data 中的值
|
||||
if result.Data.ID != "" {
|
||||
taskID = result.Data.ID
|
||||
}
|
||||
|
||||
status := result.Status
|
||||
if status == "" && result.Data.Status != "" {
|
||||
status = result.Data.Status
|
||||
}
|
||||
|
||||
fmt.Printf("[Chatfire] Parsed result - TaskID: %s, Status: %s\n", taskID, status)
|
||||
|
||||
if errMsg := getErrorMessage(result.Error); errMsg != "" {
|
||||
return nil, fmt.Errorf("chatfire error: %s", errMsg)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: taskID,
|
||||
Status: status,
|
||||
Completed: status == "completed" || status == "succeeded",
|
||||
Duration: options.Duration,
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *ChatfireClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
queryPath := c.QueryEndpoint
|
||||
if strings.Contains(queryPath, "{taskId}") {
|
||||
queryPath = strings.ReplaceAll(queryPath, "{taskId}", taskID)
|
||||
} else if strings.Contains(queryPath, "{task_id}") {
|
||||
queryPath = strings.ReplaceAll(queryPath, "{task_id}", taskID)
|
||||
} else {
|
||||
queryPath = queryPath + "/" + taskID
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + queryPath
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
// 调试日志:打印响应内容
|
||||
fmt.Printf("[Chatfire] GetTaskStatus Response body: %s\n", string(body))
|
||||
|
||||
var result ChatfireTaskResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
|
||||
}
|
||||
|
||||
// 优先使用 id 字段,其次使用 task_id
|
||||
responseTaskID := result.ID
|
||||
if responseTaskID == "" {
|
||||
responseTaskID = result.TaskID
|
||||
}
|
||||
|
||||
// 如果有 data 嵌套,优先使用 data 中的值
|
||||
if result.Data.ID != "" {
|
||||
responseTaskID = result.Data.ID
|
||||
}
|
||||
|
||||
status := result.Status
|
||||
if status == "" && result.Data.Status != "" {
|
||||
status = result.Data.Status
|
||||
}
|
||||
|
||||
// 按优先级获取 video_url:VideoURL -> Data.VideoURL -> Content.VideoURL
|
||||
videoURL := result.VideoURL
|
||||
if videoURL == "" && result.Data.VideoURL != "" {
|
||||
videoURL = result.Data.VideoURL
|
||||
}
|
||||
if videoURL == "" && result.Content.VideoURL != "" {
|
||||
videoURL = result.Content.VideoURL
|
||||
}
|
||||
|
||||
fmt.Printf("[Chatfire] Parsed result - TaskID: %s, Status: %s, VideoURL: %s\n", responseTaskID, status, videoURL)
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: responseTaskID,
|
||||
Status: status,
|
||||
Completed: status == "completed" || status == "succeeded",
|
||||
}
|
||||
|
||||
if errMsg := getErrorMessage(result.Error); errMsg != "" {
|
||||
videoResult.Error = errMsg
|
||||
}
|
||||
|
||||
if videoURL != "" {
|
||||
videoResult.VideoURL = videoURL
|
||||
videoResult.Completed = true
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
192
pkg/video/minimax_client.go
Normal file
192
pkg/video/minimax_client.go
Normal file
@@ -0,0 +1,192 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// MinimaxClient Minimax视频生成客户端
|
||||
type MinimaxClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type MinimaxSubjectReference struct {
|
||||
Type string `json:"type"`
|
||||
Image []string `json:"image"`
|
||||
}
|
||||
|
||||
type MinimaxRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
FirstFrameImage string `json:"first_frame_image,omitempty"`
|
||||
LastFrameImage string `json:"last_frame_image,omitempty"`
|
||||
SubjectReference []MinimaxSubjectReference `json:"subject_reference,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Resolution string `json:"resolution,omitempty"`
|
||||
}
|
||||
|
||||
type MinimaxResponse struct {
|
||||
TaskID string `json:"task_id"`
|
||||
Status string `json:"status"`
|
||||
BaseResp struct {
|
||||
StatusCode int `json:"status_code"`
|
||||
StatusMsg string `json:"status_msg"`
|
||||
} `json:"base_resp"`
|
||||
Video struct {
|
||||
URL string `json:"url"`
|
||||
Duration int `json:"duration"`
|
||||
} `json:"video"`
|
||||
Error struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func NewMinimaxClient(baseURL, apiKey, model string) *MinimaxClient {
|
||||
return &MinimaxClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateVideo 生成视频(支持首尾帧和主体参考)
|
||||
func (c *MinimaxClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 6,
|
||||
Resolution: "1080P",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
reqBody := MinimaxRequest{
|
||||
Prompt: prompt,
|
||||
Model: model,
|
||||
Duration: options.Duration,
|
||||
}
|
||||
|
||||
// 设置分辨率
|
||||
if options.Resolution != "" {
|
||||
reqBody.Resolution = options.Resolution
|
||||
}
|
||||
|
||||
// 如果有首帧图片(从imageURL或FirstFrameURL)
|
||||
if options.FirstFrameURL != "" {
|
||||
reqBody.FirstFrameImage = options.FirstFrameURL
|
||||
} else if imageURL != "" {
|
||||
reqBody.FirstFrameImage = imageURL
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + "/v1/video_generation"
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result MinimaxResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error.Message != "" {
|
||||
return nil, fmt.Errorf("minimax error: %s", result.Error.Message)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.TaskID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
Duration: result.Video.Duration,
|
||||
}
|
||||
|
||||
if result.Video.URL != "" {
|
||||
videoResult.VideoURL = result.Video.URL
|
||||
videoResult.Completed = true
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *MinimaxClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
endpoint := c.BaseURL + "/v1/video_generation/" + taskID
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var result MinimaxResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.TaskID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
Duration: result.Video.Duration,
|
||||
}
|
||||
|
||||
if result.Error.Message != "" {
|
||||
videoResult.Error = result.Error.Message
|
||||
}
|
||||
|
||||
if result.Video.URL != "" {
|
||||
videoResult.VideoURL = result.Video.URL
|
||||
videoResult.Completed = true
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
178
pkg/video/openai_sora_client.go
Normal file
178
pkg/video/openai_sora_client.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type OpenAISoraClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type OpenAISoraResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Model string `json:"model"`
|
||||
Status string `json:"status"`
|
||||
Progress int `json:"progress"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
CompletedAt int64 `json:"completed_at"`
|
||||
Size string `json:"size"`
|
||||
Seconds string `json:"seconds"`
|
||||
Quality string `json:"quality"`
|
||||
VideoURL string `json:"video_url"` // 直接的video_url字段
|
||||
Video struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"video"` // 嵌套的video.url字段(兼容)
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
func NewOpenAISoraClient(baseURL, apiKey, model string) *OpenAISoraClient {
|
||||
return &OpenAISoraClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *OpenAISoraClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 4,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
body := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(body)
|
||||
|
||||
writer.WriteField("model", model)
|
||||
writer.WriteField("prompt", prompt)
|
||||
|
||||
if imageURL != "" {
|
||||
writer.WriteField("input_reference", imageURL)
|
||||
}
|
||||
|
||||
if options.Duration > 0 {
|
||||
writer.WriteField("seconds", fmt.Sprintf("%d", options.Duration))
|
||||
}
|
||||
|
||||
if options.Resolution != "" {
|
||||
writer.WriteField("size", options.Resolution)
|
||||
}
|
||||
|
||||
writer.Close()
|
||||
|
||||
endpoint := c.BaseURL + "/videos"
|
||||
req, err := http.NewRequest("POST", endpoint, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
var result OpenAISoraResponse
|
||||
if err := json.Unmarshal(respBody, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error.Message != "" {
|
||||
return nil, fmt.Errorf("openai error: %s", result.Error.Message)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
}
|
||||
|
||||
// 优先使用video_url字段,兼容video.url嵌套结构
|
||||
if result.VideoURL != "" {
|
||||
videoResult.VideoURL = result.VideoURL
|
||||
} else if result.Video.URL != "" {
|
||||
videoResult.VideoURL = result.Video.URL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *OpenAISoraClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
endpoint := c.BaseURL + "/videos/" + taskID
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var result OpenAISoraResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
}
|
||||
|
||||
if result.Error.Message != "" {
|
||||
videoResult.Error = result.Error.Message
|
||||
}
|
||||
|
||||
// 优先使用video_url字段,兼容video.url嵌套结构
|
||||
if result.VideoURL != "" {
|
||||
videoResult.VideoURL = result.VideoURL
|
||||
} else if result.Video.URL != "" {
|
||||
videoResult.VideoURL = result.Video.URL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
427
pkg/video/video_client.go
Normal file
427
pkg/video/video_client.go
Normal file
@@ -0,0 +1,427 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type VideoClient interface {
|
||||
GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error)
|
||||
GetTaskStatus(taskID string) (*VideoResult, error)
|
||||
}
|
||||
|
||||
type VideoResult struct {
|
||||
TaskID string
|
||||
Status string
|
||||
VideoURL string
|
||||
ThumbnailURL string
|
||||
Duration int
|
||||
Width int
|
||||
Height int
|
||||
Error string
|
||||
Completed bool
|
||||
}
|
||||
|
||||
type VideoOptions struct {
|
||||
Model string
|
||||
Duration int
|
||||
FPS int
|
||||
Resolution string
|
||||
AspectRatio string
|
||||
Style string
|
||||
MotionLevel int
|
||||
CameraMotion string
|
||||
Seed int64
|
||||
FirstFrameURL string
|
||||
LastFrameURL string
|
||||
ReferenceImageURLs []string
|
||||
}
|
||||
|
||||
type VideoOption func(*VideoOptions)
|
||||
|
||||
func WithModel(model string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Model = model
|
||||
}
|
||||
}
|
||||
|
||||
func WithDuration(duration int) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Duration = duration
|
||||
}
|
||||
}
|
||||
|
||||
func WithFPS(fps int) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.FPS = fps
|
||||
}
|
||||
}
|
||||
|
||||
func WithResolution(resolution string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Resolution = resolution
|
||||
}
|
||||
}
|
||||
|
||||
func WithAspectRatio(ratio string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.AspectRatio = ratio
|
||||
}
|
||||
}
|
||||
|
||||
func WithStyle(style string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Style = style
|
||||
}
|
||||
}
|
||||
|
||||
func WithMotionLevel(level int) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.MotionLevel = level
|
||||
}
|
||||
}
|
||||
|
||||
func WithCameraMotion(motion string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.CameraMotion = motion
|
||||
}
|
||||
}
|
||||
|
||||
func WithSeed(seed int64) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.Seed = seed
|
||||
}
|
||||
}
|
||||
|
||||
func WithFirstFrame(url string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.FirstFrameURL = url
|
||||
}
|
||||
}
|
||||
|
||||
func WithLastFrame(url string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.LastFrameURL = url
|
||||
}
|
||||
}
|
||||
|
||||
func WithReferenceImages(urls []string) VideoOption {
|
||||
return func(o *VideoOptions) {
|
||||
o.ReferenceImageURLs = urls
|
||||
}
|
||||
}
|
||||
|
||||
type RunwayClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type RunwayRequest struct {
|
||||
Model string `json:"model"`
|
||||
PromptImage string `json:"prompt_image"`
|
||||
PromptText string `json:"prompt_text"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
type RunwayResponse struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Output struct {
|
||||
URL string `json:"url"`
|
||||
} `json:"output"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func NewRunwayClient(baseURL, apiKey, model string) *RunwayClient {
|
||||
return &RunwayClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 180 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RunwayClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 5,
|
||||
AspectRatio: "16:9",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
reqBody := RunwayRequest{
|
||||
Model: model,
|
||||
PromptImage: imageURL,
|
||||
PromptText: prompt,
|
||||
Duration: options.Duration,
|
||||
AspectRatio: options.AspectRatio,
|
||||
Seed: options.Seed,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + "/v1/video/generate"
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result RunwayResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("runway error: %s", result.Error)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "succeeded",
|
||||
}
|
||||
|
||||
if result.Output.URL != "" {
|
||||
videoResult.VideoURL = result.Output.URL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *RunwayClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
endpoint := c.BaseURL + "/v1/video/status/" + taskID
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var result RunwayResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "succeeded",
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
videoResult.Error = result.Error
|
||||
}
|
||||
|
||||
if result.Output.URL != "" {
|
||||
videoResult.VideoURL = result.Output.URL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
type PikaClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type PikaRequest struct {
|
||||
Model string `json:"model"`
|
||||
Image string `json:"image"`
|
||||
Prompt string `json:"prompt"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
AspectRatio string `json:"aspect_ratio,omitempty"`
|
||||
Motion int `json:"motion,omitempty"`
|
||||
CameraMotion string `json:"camera_motion,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
}
|
||||
|
||||
type PikaResponse struct {
|
||||
JobID string `json:"job_id"`
|
||||
Status string `json:"status"`
|
||||
Result struct {
|
||||
VideoURL string `json:"video_url"`
|
||||
} `json:"result"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func NewPikaClient(baseURL, apiKey, model string) *PikaClient {
|
||||
return &PikaClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 180 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *PikaClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 3,
|
||||
AspectRatio: "16:9",
|
||||
MotionLevel: 50,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
reqBody := PikaRequest{
|
||||
Model: model,
|
||||
Image: imageURL,
|
||||
Prompt: prompt,
|
||||
Duration: options.Duration,
|
||||
AspectRatio: options.AspectRatio,
|
||||
Motion: options.MotionLevel,
|
||||
CameraMotion: options.CameraMotion,
|
||||
Seed: options.Seed,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + "/v1/video/generate"
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result PikaResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("pika error: %s", result.Error)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.JobID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
}
|
||||
|
||||
if result.Result.VideoURL != "" {
|
||||
videoResult.VideoURL = result.Result.VideoURL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *PikaClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
endpoint := c.BaseURL + "/v1/video/status/" + taskID
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
var result PikaResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.JobID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed",
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
videoResult.Error = result.Error
|
||||
}
|
||||
|
||||
if result.Result.VideoURL != "" {
|
||||
videoResult.VideoURL = result.Result.VideoURL
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
290
pkg/video/volces_ark_client.go
Normal file
290
pkg/video/volces_ark_client.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package video
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// VolcesArkClient 火山引擎ARK视频生成客户端
|
||||
type VolcesArkClient struct {
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Model string
|
||||
Endpoint string
|
||||
QueryEndpoint string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type VolcesArkContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text,omitempty"`
|
||||
ImageURL map[string]interface{} `json:"image_url,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
}
|
||||
|
||||
type VolcesArkRequest struct {
|
||||
Model string `json:"model"`
|
||||
Content []VolcesArkContent `json:"content"`
|
||||
GenerateAudio bool `json:"generate_audio,omitempty"`
|
||||
}
|
||||
|
||||
type VolcesArkResponse struct {
|
||||
ID string `json:"id"`
|
||||
Model string `json:"model"`
|
||||
Status string `json:"status"`
|
||||
Content struct {
|
||||
VideoURL string `json:"video_url"`
|
||||
} `json:"content"`
|
||||
Usage struct {
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
CreatedAt int64 `json:"created_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
Seed int `json:"seed"`
|
||||
Resolution string `json:"resolution"`
|
||||
Ratio string `json:"ratio"`
|
||||
Duration int `json:"duration"`
|
||||
FramesPerSecond int `json:"framespersecond"`
|
||||
ServiceTier string `json:"service_tier"`
|
||||
ExecutionExpiresAfter int `json:"execution_expires_after"`
|
||||
GenerateAudio bool `json:"generate_audio"`
|
||||
Error interface{} `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
func NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *VolcesArkClient {
|
||||
if endpoint == "" {
|
||||
endpoint = "/api/v3/contents/generations/tasks"
|
||||
}
|
||||
if queryEndpoint == "" {
|
||||
queryEndpoint = endpoint
|
||||
}
|
||||
return &VolcesArkClient{
|
||||
BaseURL: baseURL,
|
||||
APIKey: apiKey,
|
||||
Model: model,
|
||||
Endpoint: endpoint,
|
||||
QueryEndpoint: queryEndpoint,
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 300 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateVideo 生成视频(支持首帧、首尾帧、参考图等多种模式)
|
||||
func (c *VolcesArkClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
|
||||
options := &VideoOptions{
|
||||
Duration: 5,
|
||||
AspectRatio: "adaptive",
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
model := c.Model
|
||||
if options.Model != "" {
|
||||
model = options.Model
|
||||
}
|
||||
|
||||
// 构建prompt文本(包含duration和ratio参数)
|
||||
promptText := prompt
|
||||
if options.AspectRatio != "" {
|
||||
promptText += fmt.Sprintf(" --ratio %s", options.AspectRatio)
|
||||
}
|
||||
if options.Duration > 0 {
|
||||
promptText += fmt.Sprintf(" --dur %d", options.Duration)
|
||||
}
|
||||
|
||||
content := []VolcesArkContent{
|
||||
{
|
||||
Type: "text",
|
||||
Text: promptText,
|
||||
},
|
||||
}
|
||||
|
||||
// 处理不同的图片模式
|
||||
// 1. 组图模式(多个reference_image)
|
||||
if len(options.ReferenceImageURLs) > 0 {
|
||||
for _, refURL := range options.ReferenceImageURLs {
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": refURL,
|
||||
},
|
||||
Role: "reference_image",
|
||||
})
|
||||
}
|
||||
} else if options.FirstFrameURL != "" && options.LastFrameURL != "" {
|
||||
// 2. 首尾帧模式
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": options.FirstFrameURL,
|
||||
},
|
||||
Role: "first_frame",
|
||||
})
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": options.LastFrameURL,
|
||||
},
|
||||
Role: "last_frame",
|
||||
})
|
||||
} else if imageURL != "" {
|
||||
// 3. 单图模式(默认)
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": imageURL,
|
||||
},
|
||||
// 单图模式不需要role
|
||||
})
|
||||
} else if options.FirstFrameURL != "" {
|
||||
// 4. 只有首帧
|
||||
content = append(content, VolcesArkContent{
|
||||
Type: "image_url",
|
||||
ImageURL: map[string]interface{}{
|
||||
"url": options.FirstFrameURL,
|
||||
},
|
||||
Role: "first_frame",
|
||||
})
|
||||
}
|
||||
|
||||
// 只有 seedance-1-5-pro 模型支持 generate_audio 参数
|
||||
generateAudio := false
|
||||
if strings.Contains(strings.ToLower(model), "seedance-1-5-pro") {
|
||||
generateAudio = true
|
||||
}
|
||||
|
||||
reqBody := VolcesArkRequest{
|
||||
Model: model,
|
||||
Content: content,
|
||||
GenerateAudio: generateAudio,
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal request: %w", err)
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + c.Endpoint
|
||||
fmt.Printf("[VolcesARK] Generating video - Endpoint: %s, FullURL: %s, Model: %s\n", c.Endpoint, endpoint, model)
|
||||
fmt.Printf("[VolcesARK] Request body: %s\n", string(jsonData))
|
||||
|
||||
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[VolcesARK] Response status: %d, body: %s\n", resp.StatusCode, string(body))
|
||||
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
|
||||
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var result VolcesArkResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[VolcesARK] Video generation initiated - TaskID: %s, Status: %s\n", result.ID, result.Status)
|
||||
|
||||
if result.Error != nil {
|
||||
errorMsg := fmt.Sprintf("%v", result.Error)
|
||||
return nil, fmt.Errorf("volces error: %s", errorMsg)
|
||||
}
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed" || result.Status == "succeeded",
|
||||
Duration: result.Duration,
|
||||
}
|
||||
|
||||
if result.Content.VideoURL != "" {
|
||||
videoResult.VideoURL = result.Content.VideoURL
|
||||
videoResult.Completed = true
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
|
||||
func (c *VolcesArkClient) GetTaskStatus(taskID string) (*VideoResult, error) {
|
||||
// 替换占位符{taskId}、{task_id}或直接拼接
|
||||
queryPath := c.QueryEndpoint
|
||||
if strings.Contains(queryPath, "{taskId}") {
|
||||
queryPath = strings.ReplaceAll(queryPath, "{taskId}", taskID)
|
||||
} else if strings.Contains(queryPath, "{task_id}") {
|
||||
queryPath = strings.ReplaceAll(queryPath, "{task_id}", taskID)
|
||||
} else {
|
||||
queryPath = queryPath + "/" + taskID
|
||||
}
|
||||
|
||||
endpoint := c.BaseURL + queryPath
|
||||
fmt.Printf("[VolcesARK] Querying task status - TaskID: %s, QueryEndpoint: %s, FullURL: %s\n", taskID, c.QueryEndpoint, endpoint)
|
||||
|
||||
req, err := http.NewRequest("GET", endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Authorization", "Bearer "+c.APIKey)
|
||||
|
||||
resp, err := c.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("send request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read response: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[VolcesARK] Response body: %s\n", string(body))
|
||||
|
||||
var result VolcesArkResponse
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("[VolcesARK] Parsed result - ID: %s, Status: %s, VideoURL: %s\n", result.ID, result.Status, result.Content.VideoURL)
|
||||
|
||||
videoResult := &VideoResult{
|
||||
TaskID: result.ID,
|
||||
Status: result.Status,
|
||||
Completed: result.Status == "completed" || result.Status == "succeeded",
|
||||
Duration: result.Duration,
|
||||
}
|
||||
|
||||
if result.Error != nil {
|
||||
videoResult.Error = fmt.Sprintf("%v", result.Error)
|
||||
}
|
||||
|
||||
if result.Content.VideoURL != "" {
|
||||
videoResult.VideoURL = result.Content.VideoURL
|
||||
videoResult.Completed = true
|
||||
}
|
||||
|
||||
return videoResult, nil
|
||||
}
|
||||
Reference in New Issue
Block a user