From 1fc8219ee4932f75f3cbfdf0e80c6ac29bdeb22f Mon Sep 17 00:00:00 2001 From: cwz_eikoh Date: Mon, 19 Jan 2026 18:06:05 +0800 Subject: [PATCH] feature(llm): adapt frontend (#23906) * feature(llm): add llm-list details, add autostart for llm-save-instant-model * fix(llm): adjust some interfaces * fix: name-dup problem when create llm * fix: install instant-model by id rather than modelID * fix(llm): add llm_id for mcp-agent * feature(llm): move network from sku to instance * feature(llm): add LLMType for llm-image * feature(llm): add gpuMemoryRequired & ollama-registry yaml * feature(llm): add url-get interface * feature(llm): support mcp in mcp-agent-chat * fix(llm): abstract ollama registry --- cmd/climc/shell/llm/llm.go | 1 + cmd/climc/shell/llm/mcp_agent.go | 96 +++---- pkg/apis/llm/image.go | 27 +- pkg/apis/llm/instantmodel.go | 9 + pkg/apis/llm/llm.go | 68 ++++- pkg/apis/llm/llm_const.go | 23 +- pkg/apis/llm/llm_instant_model.go | 6 +- pkg/apis/llm/mcp_agent.go | 25 +- pkg/apis/llm/ollama_registry.go | 96 +++++++ pkg/apis/llm/sku.go | 11 +- pkg/llm/drivers/llm_client/ollama.go | 34 +-- pkg/llm/drivers/llm_container/ollama.go | 10 +- pkg/llm/models/image.go | 19 ++ pkg/llm/models/instantmodel.go | 90 +++++-- pkg/llm/models/llm.go | 236 +++++++++++++++++- pkg/llm/models/llm_base.go | 41 ++- pkg/llm/models/llm_base_pod.go | 13 +- pkg/llm/models/llm_instant_model_sync.go | 51 +++- pkg/llm/models/llm_pod.go | 2 +- pkg/llm/models/llm_save_instant_model.go | 24 +- pkg/llm/models/llm_sku.go | 108 +++++--- pkg/llm/models/mcp_agent.go | 234 +++++++++-------- pkg/llm/models/sku.go | 37 +-- pkg/llm/service/handler.go | 11 + .../llm/llm_start_save_model_image_task.go | 13 +- pkg/mcclient/options/llm/image.go | 18 +- pkg/mcclient/options/llm/llm.go | 54 ++-- pkg/mcclient/options/llm/llm_sku_base.go | 10 +- pkg/mcclient/options/llm/mcp_agent.go | 17 +- scripts/sync_dify_images.sh | 2 +- 30 files changed, 972 insertions(+), 414 deletions(-) create mode 100644 pkg/apis/llm/ollama_registry.go diff --git a/cmd/climc/shell/llm/llm.go b/cmd/climc/shell/llm/llm.go index 71ac71efc9..32b6fa854b 100644 --- a/cmd/climc/shell/llm/llm.go +++ b/cmd/climc/shell/llm/llm.go @@ -17,6 +17,7 @@ func init() { cmd.BatchPerform("stop", new(options.LLMStopOptions)) cmd.BatchPerform("start", new(options.LLMStartOptions)) cmd.Get("probed-models", new(options.LLMIdOptions)) + cmd.Get("url", new(options.LLMIdOptions)) cmd.Perform("save-instant-model", new(options.LLMSaveInstantModelOptions)) cmd.Perform("quick-models", new(options.LLMQuickModelsOptions)) } diff --git a/cmd/climc/shell/llm/mcp_agent.go b/cmd/climc/shell/llm/mcp_agent.go index 9fc57da40a..bb32b49001 100644 --- a/cmd/climc/shell/llm/mcp_agent.go +++ b/cmd/climc/shell/llm/mcp_agent.go @@ -1,9 +1,11 @@ package llm import ( + "bufio" "fmt" "io" "net/url" + "strings" "yunion.io/x/onecloud/cmd/climc/shell" "yunion.io/x/onecloud/pkg/mcclient" @@ -25,47 +27,55 @@ func init() { cmd.Get("tool-request", new(options.MCPAgentToolRequestOptions)) // cmd.Get("chat-test", new(options.MCPAgentChatTestOptions)) cmd.Get("request", new(options.MCPAgentMCPAgentRequestOptions)) - shell.R(&options.MCPAgentChatTestOptions{}, "mcp-agent-chat", "Chat with MCP Agent (Stream)", func(s *mcclient.ClientSession, args *options.MCPAgentChatTestOptions) error { - id, err := modules.MCPAgent.GetId(s, args.ID, nil) - if err != nil { - return err - } - - path := fmt.Sprintf("/mcp_agents/%s/chat-stream?message=%s", id, url.QueryEscape(args.Message)) - - resp, err := s.RawVersionRequest( - 
modules.MCPAgent.ServiceType(), - modules.MCPAgent.EndpointType(), - "GET", - path, - nil, - nil, - ) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != 200 { - // Read error body - body, _ := io.ReadAll(resp.Body) - return fmt.Errorf("Error: %s %s", resp.Status, string(body)) - } - - buffer := make([]byte, 1024) - for { - n, err := resp.Body.Read(buffer) - if n > 0 { - fmt.Print(string(buffer[:n])) - } - if err != nil { - if err == io.EOF { - break - } - return err - } - } - fmt.Println() - return nil - }) + shell.R(&options.MCPAgentMCPAgentRequestOptions{}, "mcp-agent-chat", "Chat with MCP Agent (Stream)", chatStream) +} + +func chatStream(s *mcclient.ClientSession, args *options.MCPAgentMCPAgentRequestOptions) error { + id, err := modules.MCPAgent.GetId(s, args.ID, nil) + if err != nil { + return err + } + + path := fmt.Sprintf("/mcp_agents/%s/chat-stream?message=%s", id, url.QueryEscape(args.Message)) + + resp, err := s.RawVersionRequest( + modules.MCPAgent.ServiceType(), + modules.MCPAgent.EndpointType(), + "GET", + path, + nil, + nil, + ) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + // Read error body + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("Error: %s %s", resp.Status, string(body)) + } + + scanner := bufio.NewScanner(resp.Body) + var eventData []string + for scanner.Scan() { + line := scanner.Text() + if line == "" { + if len(eventData) > 0 { + fmt.Print(strings.Join(eventData, "\n")) + eventData = nil + } + continue + } + if after, found := strings.CutPrefix(line, "data: "); found { + eventData = append(eventData, after) + } + } + + if err := scanner.Err(); err != nil { + return err + } + fmt.Println() + return nil } diff --git a/pkg/apis/llm/image.go b/pkg/apis/llm/image.go index f84c801924..f5bda4cd0a 100644 --- a/pkg/apis/llm/image.go +++ b/pkg/apis/llm/image.go @@ -1,12 +1,35 @@ package llm -import "yunion.io/x/onecloud/pkg/apis" +import ( + "yunion.io/x/pkg/util/sets" + + "yunion.io/x/onecloud/pkg/apis" +) + +type LLMImageType string + +const ( + LLM_IMAGE_TYPE_OLLAMA LLMImageType = "ollama" + LLM_IMAGE_TYPE_DIFY LLMImageType = "dify" +) + +var ( + LLM_IMAGE_TYPES = sets.NewString( + string(LLM_IMAGE_TYPE_OLLAMA), + string(LLM_IMAGE_TYPE_DIFY), + ) +) + +func IsLLMImageType(t string) bool { + return LLM_IMAGE_TYPES.Has(t) +} type LLMImageListInput struct { apis.SharableVirtualResourceListInput ImageLabel string `json:"image_label"` ImageName string `json:"image_name"` + LLMType string `json:"llm_type"` } type LLMImageCreateInput struct { @@ -15,6 +38,7 @@ type LLMImageCreateInput struct { ImageName string `json:"image_name"` ImageLabel string `json:"image_label"` CredentialId string `json:"credential_id"` + LLMType string `json:"llm_type"` } type LLMImageUpdateInput struct { @@ -23,4 +47,5 @@ type LLMImageUpdateInput struct { ImageName *string `json:"image_name,omitempty"` ImageLabel *string `json:"image_label,omitempty"` CredentialId *string `json:"credential_id,omitempty"` + LLMType *string `json:"llm_type,omitempty"` } diff --git a/pkg/apis/llm/instantmodel.go b/pkg/apis/llm/instantmodel.go index 1d3f46f94c..697c339388 100644 --- a/pkg/apis/llm/instantmodel.go +++ b/pkg/apis/llm/instantmodel.go @@ -62,6 +62,15 @@ type InstantModelDetails struct { CachedCount int `json:"cached_count"` IconBase64 string `json:"icon_base64"` + + MountedByLLMs []MountedByLLMInfo `json:"mounted_by_llms"` + + GPUMemoryRequired int64 `json:"gpu_memory_required"` +} + +type MountedByLLMInfo struct { 
+ LlmId string `json:"llm_id"` + LlmName string `json:"llm_name"` } type InstantModelSyncstatusInput struct { diff --git a/pkg/apis/llm/llm.go b/pkg/apis/llm/llm.go index 84db10a07f..db4901c9e2 100644 --- a/pkg/apis/llm/llm.go +++ b/pkg/apis/llm/llm.go @@ -1,6 +1,8 @@ package llm import ( + "time" + "yunion.io/x/onecloud/pkg/apis" "yunion.io/x/onecloud/pkg/cloudcommon/db/taskman" ) @@ -9,14 +11,65 @@ const ( SERVICE_TYPE = "llm" ) +type LLMBaseListDetails struct { + apis.VirtualResourceDetails + + // AccessInfo []AccessInfoListOutput + Volume Volume `json:"volume"` + + LLMImage string `json:"llm_image"` + LLMImageLable string `json:"llm_image_lable"` + LLMImageName string `json:"llm_image_name"` + + VcpuCount int `json:"vcpu_count"` + VmemSizeMb int `json:"vmem_size_mb"` + Devices *Devices `json:"devices"` + + NetworkType string `json:"network_type"` + NetworkId string `json:"network_id"` + Network string `json:"network"` + + EffectBandwidthMbps int `json:"effect_bandwidth_mbps"` + StartTime time.Time `json:"start_time"` + + LLMStatus string `json:"llm_status"` + + Server string `json:"server"` + + HostInfo + + Zone string `json:"zone"` + ZoneId string `json:"zone_id"` + + AdbPublic string `json:"adb_public"` + AdbAccess string `json:"adb_access"` +} + +type MountedModelInfo struct { + FullName string `json:"fullname"` // 模型全名,如: qwen3:8b + Id string `json:"id"` // 模型ID,如: 500a1f067a9f +} + +type LLMListDetails struct { + LLMBaseListDetails + + LLMSku string + + MountedModels []MountedModelInfo +} + type LLMBaseCreateInput struct { apis.VirtualResourceCreateInput - PreferHost string `json:"prefer_host"` - AutoStart bool `json:"auto_start"` - BandwidthMB int `json:"bandwidth_mb"` - DebugMode bool `json:"debug_mode"` - RootfsUnlimit bool `json:"rootfs_unlimit"` + PreferHost string `json:"prefer_host"` + AutoStart bool `json:"auto_start"` + + NetworkType string `json:"network_type"` + NetworkId string `json:"network_id"` + + BandwidthMB int `json:"bandwidth_mb"` + DebugMode bool `json:"debug_mode"` + RootfsUnlimit bool `json:"rootfs_unlimit"` } type LLMCreateInput struct { @@ -33,6 +86,9 @@ type LLMBaseListInput struct { Host string `json:"host"` Status []string `json:"status"` + NetworkType string `json:"network_type"` + NetworkId string `json:"network_id"` + NoVolume *bool `json:"no_volume"` ListenPort int `json:"listen_port"` PublicIp string `json:"public_ip"` @@ -56,6 +112,8 @@ type ModelInfo struct { DisplayName string `json:"display_name"` // 秒装模型 tag,如: 7b Tag string `json:"tag"` + // 秒装模型 LLM 类型 + LlmType string `json:"llm_type"` } type LLMPerformQuickModelsInput struct { diff --git a/pkg/apis/llm/llm_const.go b/pkg/apis/llm/llm_const.go index 48f68d3870..9d629a13f3 100644 --- a/pkg/apis/llm/llm_const.go +++ b/pkg/apis/llm/llm_const.go @@ -6,7 +6,7 @@ const ( const ( /* 未知 */ - LLM_STATUS_UNKOWN = "unkown" + LLM_STATUS_UNKNOWN = "unknown" /* 创建失败 */ LLM_STATUS_CREATE_FAIL = "create_fail" @@ -36,21 +36,12 @@ const ( /* 删除 */ LLM_STATUS_DELETED = "deleted" - LLM_STATUS_CREATING_POD = "creating_pod" - LLM_STATUS_CREAT_POD_FAILED = "creat_pod_failed" - LLM_STATUS_PULLING_MODEL = "pulling_model" - LLM_STATUS_GET_MANIFESTS_FAILED = "get_manifests_failed" - LLM_STATUS_DOWNLOADING_BLOBS = "downloading_blobs" - LLM_STATUS_DOWNLOADING_BLOBS_FAILED = "downloading_blobs_failed" - LLM_STATUS_FETCHING_GGUF_FILE = "fetching_gguf_file" - LLM_STATUS_FETCH_GGUF_FILE_FAILED = "fetch_gguf_failed" - LLM_STATUS_CREATING_GGUF_MODEL = "creating_gguf_model" - LLM_STATUS_CREATE_GGUF_MODEL_FAILED = 
"create_gguf_model_failed" - LLM_STATUS_PULLED_MODEL = "pulled_model" - LLM_STATUS_PULL_MODEL_FAILED = "pull_model_failed" - LLM_STATUS_START_DELETE = "start_delete" - LLM_STATUS_DELETING = "deleting" - LLM_STATUS_DELETE_FAILED = "delete_fail" + LLM_LLM_STATUS_NO_SERVER = "no_server" + LLM_LLM_STATUS_NO_CONTAINER = "no_container" + + LLM_STATUS_START_DELETE = "start_delete" + LLM_STATUS_DELETING = "deleting" + LLM_STATUS_DELETE_FAILED = "delete_fail" ) type TQuickModelMethod string diff --git a/pkg/apis/llm/llm_instant_model.go b/pkg/apis/llm/llm_instant_model.go index 447658a91c..08c8759edc 100644 --- a/pkg/apis/llm/llm_instant_model.go +++ b/pkg/apis/llm/llm_instant_model.go @@ -14,10 +14,10 @@ type LLMInternalInstantMdlInfo struct { type LLMSaveInstantModelInput struct { apis.ProjectizedResourceCreateInput - ModelId string `json:"model_id"` - ImageName string `json:"image_name"` + ModelId string `json:"model_id"` + ModelFullName string `json:"model_full_name"` InstantModelId string `json:"instant_model_id"` - // AutoRestart bool `json:"auto_restart"` + AutoRestart bool `json:"auto_restart"` } diff --git a/pkg/apis/llm/mcp_agent.go b/pkg/apis/llm/mcp_agent.go index 3714cc6949..e2dca3d6d1 100644 --- a/pkg/apis/llm/mcp_agent.go +++ b/pkg/apis/llm/mcp_agent.go @@ -20,9 +20,18 @@ const ( - 管理虚拟机(创建、启动、停止、重启、删除、重置密码) - 获取虚拟机监控信息和实时统计数据 +## 重要规则(必须严格遵守) +**如果用户的问题涉及查询、创建、修改或删除云资源,你必须先调用相应的工具,而不是直接回答。** +- 对于需要查询资源的问题(如"列出虚拟机"、"查询状态"等),必须调用工具获取数据后再回答 +- 对于需要操作资源的问题(如"创建"、"启动"、"停止"等),必须调用工具执行操作后再回答 +- 只有在以下情况才可以直接回复: + 1. 用户只是询问一般性问题(如"你能做什么"、"如何使用"等) + 2. 没有合适的工具可以解决用户的问题 + 3. 工具调用失败后需要向用户说明错误原因 + ## 工作流程 1. 理解用户的需求 -2. 选择合适的工具来完成任务 +2. **优先检查是否有合适的工具可以完成任务,如果有则必须调用工具** 3. 分析工具返回的结果 4. 如果需要更多信息,继续调用其他工具 5. 最后用自然语言总结结果给用户 @@ -31,6 +40,7 @@ const ( - 认证信息已由系统自动处理,调用工具时无需提供认证参数 - 如果工具调用失败,尝试分析错误原因并告知用户 - 回复时使用中文,语言简洁明了 +- **不要在没有调用工具的情况下直接回答需要查询或操作资源的问题** ` ) @@ -78,11 +88,8 @@ type MCPAgentUpdateInput struct { type MCPAgentDetails struct { apis.SharableVirtualResourceDetails - LLMUrl string `json:"llm_url"` - LLMDriver string `json:"llm_driver"` - Model string `json:"model"` - ApiKey string `json:"api_key"` - McpServer string `json:"mcp_server"` + LLMId string `json:"llm_id"` + LLMName string `json:"llm_name"` } type LLMToolRequestInput struct { @@ -90,12 +97,8 @@ type LLMToolRequestInput struct { Arguments map[string]interface{} `json:"arguments"` } -type LLMChatTestInput struct { - Message string `json:"message" help:"test message to send to LLM"` -} - type LLMMCPAgentRequestInput struct { - Query string `json:"query" help:"query to send to MCP agent"` + Message string `json:"message" help:"message to send to MCP agent"` } // MCPAgentResponse 表示 Agent 响应 diff --git a/pkg/apis/llm/ollama_registry.go b/pkg/apis/llm/ollama_registry.go new file mode 100644 index 0000000000..f9e107507a --- /dev/null +++ b/pkg/apis/llm/ollama_registry.go @@ -0,0 +1,96 @@ +package llm + +import ( + "yunion.io/x/jsonutils" +) + +type SOllamaTag struct { + Name string `json:"name" yaml:"name"` + ModelSize string `json:"model_size" yaml:"model_size"` + ContextLength string `json:"context_length" yaml:"context_length"` + Capabilities []string `json:"capabilities" yaml:"capabilities"` + IsLatest bool `json:"is_latest,omitempty" yaml:"is_latest,omitempty"` +} + +func (t SOllamaTag) Latest() SOllamaTag { + t.IsLatest = true + return t +} + +type SOllamaModel struct { + Name string `json:"name" yaml:"name"` + Description string `json:"description" yaml:"description"` + Tags []SOllamaTag `json:"tags" yaml:"tags"` +} + +// 
SOllamaRegistry 顶层结构,用于生成 +// ollama: +// - name: xxx +// ... +type SOllamaRegistry struct { + Ollama []SOllamaModel `json:"ollama" yaml:"ollama"` +} + +func NewOllamaTag(name, size, contextLen string, caps []string) SOllamaTag { + return SOllamaTag{ + Name: name, + ModelSize: size, + ContextLength: contextLen, + Capabilities: caps, + } +} + +func NewOllamaModel(name, desc string, tags ...SOllamaTag) SOllamaModel { + return SOllamaModel{ + Name: name, + Description: desc, + Tags: tags, + } +} + +func NewOllamaRegistry(models ...SOllamaModel) SOllamaRegistry { + return SOllamaRegistry{ + Ollama: models, + } +} + +var ( + CapText = []string{"Text"} + CapVision = []string{"Text", "Image"} +) + +var OllamaRegistry = NewOllamaRegistry( + NewOllamaModel( + "qwen3-vl", + "Qwen3-vl is the most powerful vision-language model in the Qwen model family to date.", + NewOllamaTag("2b", "1.9GB", "256K", CapVision), + NewOllamaTag("4b", "3.3GB", "256K", CapVision), + NewOllamaTag("8b", "6.1GB", "256K", CapVision).Latest(), + NewOllamaTag("30b", "20GB", "256K", CapVision), + NewOllamaTag("32b", "21GB", "256K", CapVision), + ), + NewOllamaModel( + "qwen3", + "Qwen3 is the latest generation of large language models in Qwen series, offering a comprehensive suite of dense and mixture-of-experts (MoE) models.", + NewOllamaTag("0.6b", "523MB", "40K", CapText), + NewOllamaTag("1.7b", "1.4GB", "40K", CapText), + NewOllamaTag("4b", "2.5GB", "256K", CapText), + NewOllamaTag("8b", "5.2GB", "40K", CapText).Latest(), + NewOllamaTag("14b", "9.3GB", "40K", CapText), + NewOllamaTag("30b", "19GB", "256K", CapText), + NewOllamaTag("32b", "20GB", "40K", CapText), + ), + NewOllamaModel( + "qwen2.5-coder", + "The latest series of Code-Specific Qwen models, with significant improvements in code generation, code reasoning, and code fixing.", + NewOllamaTag("latest", "4.7GB", "32K", CapText), + NewOllamaTag("0.5b", "398MB", "32K", CapText), + NewOllamaTag("1.5b", "986MB", "32K", CapText), + NewOllamaTag("3b", "1.9GB", "32K", CapText), + NewOllamaTag("7b", "4.7GB", "32K", CapText).Latest(), + NewOllamaTag("14b", "9.0GB", "32K", CapText), + NewOllamaTag("32b", "20GB", "32K", CapText), + ), +) + +var OLLAMA_REGISTRY_YAML = jsonutils.Marshal(OllamaRegistry).YAMLString() diff --git a/pkg/apis/llm/sku.go b/pkg/apis/llm/sku.go index 08478f4f81..c2b1610d81 100644 --- a/pkg/apis/llm/sku.go +++ b/pkg/apis/llm/sku.go @@ -126,12 +126,9 @@ type MountedAppResourceDetails struct { type LLMSKuBaseCreateInput struct { apis.SharableVirtualResourceCreateInput - Cpu int `json:"cpu"` - Memory int `json:"memory"` - - NetworkType string `json:"network_type"` - NetworkId string `json:"network_id"` - Bandwidth int `json:"bandwidth"` + Cpu int `json:"cpu"` + Memory int `json:"memory"` + Bandwidth int `json:"bandwidth"` Volumes *Volumes `json:"volumes"` PortMappings *PortMappings `json:"port_mappings"` @@ -153,8 +150,6 @@ type LLMSkuBaseUpdateInput struct { StorageType *string `json:"storage_type"` Volumes *Volumes `json:"volumes"` - NetworkType *string `json:"network_type"` - NetworkId *string `json:"network_id"` Bandwidth *int `json:"bandwidth"` PortMappings *PortMappings `json:"port_mappings"` Devices *Devices `json:"devices"` diff --git a/pkg/llm/drivers/llm_client/ollama.go b/pkg/llm/drivers/llm_client/ollama.go index 2b05d9e6c5..db6a652d3f 100644 --- a/pkg/llm/drivers/llm_client/ollama.go +++ b/pkg/llm/drivers/llm_client/ollama.go @@ -40,20 +40,26 @@ func convertMessages(messages interface{}) ([]OllamaChatMessage, error) { } else if msgs, ok := 
messages.([]models.ILLMChatMessage); ok { ollamaMessages = make([]OllamaChatMessage, len(msgs)) for i, msg := range msgs { - ollamaMessages[i] = OllamaChatMessage{ - Role: msg.GetRole(), - Content: msg.GetContent(), - } - // 转换工具调用 - if toolCalls := msg.GetToolCalls(); len(toolCalls) > 0 { - ollamaMessages[i].ToolCalls = make([]OllamaToolCall, len(toolCalls)) - for j, tc := range toolCalls { - fc := tc.GetFunction() - ollamaMessages[i].ToolCalls[j] = OllamaToolCall{ - Function: OllamaFunctionCall{ - Name: fc.GetName(), - Arguments: fc.GetArguments(), - }, + // 如果 msg 已经是 *OllamaChatMessage,直接解引用使用 + if ollamaMsg, ok := msg.(*OllamaChatMessage); ok { + ollamaMessages[i] = *ollamaMsg + } else { + // 否则通过接口方法获取 + ollamaMessages[i] = OllamaChatMessage{ + Role: msg.GetRole(), + Content: msg.GetContent(), + } + // 转换工具调用 + if toolCalls := msg.GetToolCalls(); len(toolCalls) > 0 { + ollamaMessages[i].ToolCalls = make([]OllamaToolCall, len(toolCalls)) + for j, tc := range toolCalls { + fc := tc.GetFunction() + ollamaMessages[i].ToolCalls[j] = OllamaToolCall{ + Function: OllamaFunctionCall{ + Name: fc.GetName(), + Arguments: fc.GetArguments(), + }, + } } } } diff --git a/pkg/llm/drivers/llm_container/ollama.go b/pkg/llm/drivers/llm_container/ollama.go index 13c52c4b44..9a54102b3a 100644 --- a/pkg/llm/drivers/llm_container/ollama.go +++ b/pkg/llm/drivers/llm_container/ollama.go @@ -597,17 +597,13 @@ func parseModelName(path string) string { } func (o *ollama) GetLLMUrl(ctx context.Context, userCred mcclient.TokenCredential, llm *models.SLLM) (string, error) { - sku, err := llm.GetLLMSku("") - if err != nil { - return "", errors.Wrap(err, "get llm sku") - } // 查询 accessinfo accessInfo := &models.SAccessInfo{} q := models.GetAccessInfoManager().Query().Equals("llm_id", llm.Id) - err = q.First(accessInfo) + err := q.First(accessInfo) if err != nil { if errors.Cause(err) == sql.ErrNoRows { - // 如果没有 accessinfo,使用默认 localhost + // 如果没有 accessinfo,使用对应主机 server, err := llm.GetServer(ctx) if err != nil { return "", errors.Wrap(err, "get server") @@ -624,7 +620,7 @@ func (o *ollama) GetLLMUrl(ctx context.Context, userCred mcclient.TokenCredentia } // 判断网络类型 - networkType := sku.NetworkType + networkType := llm.NetworkType if networkType == string(computeapi.NETWORK_TYPE_GUEST) { // guest 网络:使用 LLM IP + 默认端口 if len(llm.LLMIp) == 0 { diff --git a/pkg/llm/models/image.go b/pkg/llm/models/image.go index d883789df7..0b55d51d4c 100644 --- a/pkg/llm/models/image.go +++ b/pkg/llm/models/image.go @@ -3,6 +3,7 @@ package models import ( "context" "fmt" + "strings" "yunion.io/x/jsonutils" "yunion.io/x/pkg/errors" @@ -11,6 +12,7 @@ import ( identityapi "yunion.io/x/onecloud/pkg/apis/identity" api "yunion.io/x/onecloud/pkg/apis/llm" "yunion.io/x/onecloud/pkg/cloudcommon/db" + "yunion.io/x/onecloud/pkg/httperrors" "yunion.io/x/onecloud/pkg/llm/options" "yunion.io/x/onecloud/pkg/mcclient" "yunion.io/x/onecloud/pkg/mcclient/auth" @@ -50,6 +52,7 @@ type SLLMImage struct { ImageLabel string `width:"64" charset:"utf8" nullable:"false" list:"user" create:"admin_optional" update:"user"` CredentialId string `width:"128" charset:"utf8" nullable:"true" list:"user" create:"admin_optional" update:"user"` + LLMType string `width:"128" charset:"ascii" nullable:"false" list:"user" create:"admin_optional" update:"user"` } func fetchImageCredential(ctx context.Context, userCred mcclient.TokenCredential, cid string) (*identityapi.CredentialDetails, error) { @@ -81,6 +84,12 @@ func (man *SLLMImageManager) ValidateCreateData(ctx 
context.Context, userCred mc input.CredentialId = cred.Id } + if len(input.LLMType) > 0 { + if !api.IsLLMImageType(input.LLMType) { + return input, errors.Wrap(httperrors.ErrInputParameter, "llm_type must be one of "+strings.Join(api.LLM_IMAGE_TYPES.List(), ",")) + } + } + input.Status = api.STATUS_READY return input, nil } @@ -99,6 +108,13 @@ func (man *SLLMImageManager) ValidateUpdateData(ctx context.Context, userCred mc } input.CredentialId = &cred.Id } + + if nil != input.LLMType && len(*input.LLMType) > 0 { + if !api.IsLLMImageType(*input.LLMType) { + return input, errors.Wrap(httperrors.ErrInputParameter, "llm_type must be one of "+strings.Join(api.LLM_IMAGE_TYPES.List(), ",")) + } + } + return input, nil } @@ -125,6 +141,9 @@ func (man *SLLMImageManager) ListItemFilter( if len(input.ImageName) > 0 { q = q.Equals("image_name", input.ImageName) } + if len(input.LLMType) > 0 { + q = q.Equals("llm_type", input.LLMType) + } return q, nil } diff --git a/pkg/llm/models/instantmodel.go b/pkg/llm/models/instantmodel.go index 02c7c5012c..6ded75c559 100644 --- a/pkg/llm/models/instantmodel.go +++ b/pkg/llm/models/instantmodel.go @@ -151,7 +151,7 @@ func (man *SInstantModelManager) FetchCustomizeColumns( res := make([]apis.InstantModelDetails, len(objs)) imageIds := make([]string, 0) - // mdlNames := make([]string, 0) + mdlIds := make([]string, 0) virows := man.SSharableVirtualResourceBaseManager.FetchCustomizeColumns(ctx, userCred, query, objs, fields, isList) for i := range res { @@ -160,9 +160,9 @@ func (man *SInstantModelManager) FetchCustomizeColumns( if len(instModel.ImageId) > 0 { imageIds = append(imageIds, instModel.ImageId) } - // if len(instModel.ModelName) > 0 { - // mdlNames = append(mdlNames, instModel.ModelName) - // } + if len(instModel.ModelId) > 0 { + mdlIds = append(mdlIds, instModel.ModelId) + } } s := auth.GetSession(ctx, userCred, options.Options.Region) @@ -238,6 +238,46 @@ func (man *SInstantModelManager) FetchCustomizeColumns( } } + + llmInstModelQ := GetLLMInstantModelManager().Query().In("model_id", mdlIds).IsFalse("deleted") + llmInstModels := make([]SLLMInstantModel, 0) + err := db.FetchModelObjects(GetLLMInstantModelManager(), llmInstModelQ, &llmInstModels) + if err != nil { + log.Errorf("fetch llm instant models fail %s", err) + } + + llmIds := make([]string, 0) + for i := range llmInstModels { + if !utils.IsInArray(llmInstModels[i].LlmId, llmIds) { + llmIds = append(llmIds, llmInstModels[i].LlmId) + } + } + + llmMap := make(map[string]SLLM) + if len(llmIds) > 0 { + err = db.FetchModelObjectsByIds(GetLLMManager(), "id", llmIds, &llmMap) + if err != nil { + log.Errorf("FetchModelObjectsByIds LLMManager fail %s", err) + } + } + + modelMountedByMap := make(map[string][]apis.MountedByLLMInfo) + for i := range llmInstModels { + llmInstModel := llmInstModels[i] + llm, ok := llmMap[llmInstModel.LlmId] + if !ok { + continue + } + info := apis.MountedByLLMInfo{ + LlmId: llmInstModel.LlmId, + LlmName: llm.Name, + } + if _, ok := modelMountedByMap[llmInstModel.ModelId]; !ok { + modelMountedByMap[llmInstModel.ModelId] = make([]apis.MountedByLLMInfo, 0) + } + modelMountedByMap[llmInstModel.ModelId] = append(modelMountedByMap[llmInstModel.ModelId], info) + } + for i := range res { instModel := objs[i].(*SInstantModel) if img, ok := imageMap[instModel.ImageId]; ok { @@ -247,6 +287,11 @@ func (man *SInstantModelManager) FetchCustomizeColumns( res[i].CacheCount = status.CacheCount res[i].CachedCount = status.CachedCount } + if mountedBy, ok := modelMountedByMap[instModel.ModelId]; 
ok { + res[i].MountedByLLMs = mountedBy + } + + res[i].GPUMemoryRequired = instModel.GetEstimatedVramSizeMb() } return res } @@ -369,7 +414,7 @@ func (model *SInstantModel) PostCreate( if err != nil { return } - if input.DoNotImport == nil || !*input.DoNotImport { + if input.ImageId == "" && (input.DoNotImport == nil || !*input.DoNotImport) { model.startImportTask(ctx, userCred, apis.InstantModelImportInput{ LlmType: input.LlmType, ModelName: input.ModelName, @@ -533,15 +578,15 @@ func (model *SInstantModel) PerformEnable( return nil, errors.Wrapf(errors.ErrInvalidStatus, "cannot enable model of status %s", model.Status) } // check duplicate - { - existing, err := GetInstantModelManager().findInstantModel(model.ModelId, model.ModelTag, true) - if err != nil { - return nil, errors.Wrap(err, "findInstantModel") - } - if existing != nil && existing.Id != model.Id { - return nil, errors.Wrapf(errors.ErrDuplicateId, "model of modelId %s tag %s has been enabled", model.ModelId, model.ModelTag) - } - } + // { + // existing, err := GetInstantModelManager().findInstantModel(model.ModelId, model.ModelTag, true) + // if err != nil { + // return nil, errors.Wrap(err, "findInstantModel") + // } + // if existing != nil && existing.Id != model.Id { + // return nil, errors.Wrapf(errors.ErrDuplicateId, "model of modelId %s tag %s has been enabled", model.ModelId, model.ModelTag) + // } + // } _, err := db.Update(model, func() error { model.SEnabledResourceBase.SetEnabled(true) return nil @@ -885,6 +930,18 @@ func (model *SInstantModel) GetActualSizeMb() int32 { return int32(model.Size / 1024 / 1024) } +func (model *SInstantModel) GetEstimatedVramSizeBytes() int64 { + if model.Size <= 0 { + return 0 + } + // 1.0x 基础权重 + 0.15x 动态开销(KV Cache) + 500MB 框架固定开销 + return int64(float64(model.Size)*1.15) + 500*1024*1024 +} + +func (model *SInstantModel) GetEstimatedVramSizeMb() int64 { + return model.GetEstimatedVramSizeBytes() / 1024 / 1024 +} + func (model *SInstantModel) CleanupImportTmpDir(ctx context.Context, userCred mcclient.TokenCredential, tmpDir string) error { // sync image status err := model.syncImageStatus(ctx, userCred) @@ -900,3 +957,8 @@ func (model *SInstantModel) CleanupImportTmpDir(ctx context.Context, userCred mc } return nil } + +// GetOllamaRegistryYAML returns the Ollama registry YAML content +func (man *SInstantModelManager) GetOllamaRegistryYAML() string { + return apis.OLLAMA_REGISTRY_YAML +} diff --git a/pkg/llm/models/llm.go b/pkg/llm/models/llm.go index 65b6aed6b8..f6c11d6329 100644 --- a/pkg/llm/models/llm.go +++ b/pkg/llm/models/llm.go @@ -3,6 +3,7 @@ package models import ( "context" "database/sql" + "fmt" "strings" "yunion.io/x/jsonutils" @@ -17,9 +18,13 @@ import ( "yunion.io/x/onecloud/pkg/cloudcommon/db" "yunion.io/x/onecloud/pkg/cloudcommon/db/taskman" "yunion.io/x/onecloud/pkg/httperrors" + "yunion.io/x/onecloud/pkg/llm/options" llmutils "yunion.io/x/onecloud/pkg/llm/utils" "yunion.io/x/onecloud/pkg/mcclient" + "yunion.io/x/onecloud/pkg/mcclient/auth" "yunion.io/x/onecloud/pkg/mcclient/modules/compute" + computeoptions "yunion.io/x/onecloud/pkg/mcclient/options/compute" + "yunion.io/x/onecloud/pkg/util/stringutils2" ) var llmManager *SLLMManager @@ -112,18 +117,199 @@ func (man *SLLMManager) ListItemFilter(ctx context.Context, q *sqlchemy.SQuery, q = q.Equals("llm_image_id", imgObj.GetId()) } - // if input.Unused != nil { - // instanceQ := GetDesktopInstanceManager().Query().SubQuery() - // if *input.Unused { - // q = q.NotEquals("id", 
instanceQ.Query(instanceQ.Field("desktop_id")).SubQuery()) - // } else { - // q = q.Join(instanceQ, sqlchemy.Equals(q.Field("id"), instanceQ.Field("desktop_id"))) - // } - // } - return q, nil } +func (man *SLLMManager) FetchCustomizeColumns( + ctx context.Context, + userCred mcclient.TokenCredential, + query jsonutils.JSONObject, + objs []interface{}, + fields stringutils2.SSortedStrings, + isList bool, +) []api.LLMListDetails { + virtRows := man.SVirtualResourceBaseManager.FetchCustomizeColumns(ctx, userCred, query, objs, fields, isList) + llms := []SLLM{} + jsonutils.Update(&llms, objs) + res := make([]api.LLMListDetails, len(objs)) + for i := 0; i < len(res); i++ { + res[i].VirtualResourceDetails = virtRows[i] + } + + ids := make([]string, len(llms)) + skuIds := make([]string, len(llms)) + imgIds := make([]string, len(llms)) + serverIds := []string{} + networkIds := []string{} + for idx, llm := range llms { + ids[idx] = llm.Id + skuIds[idx] = llm.LLMSkuId + imgIds[idx] = llm.LLMImageId + if !utils.IsInArray(llm.SvrId, serverIds) { + serverIds = append(serverIds, llm.SvrId) + } + if len(llm.NetworkId) > 0 { + networkIds = append(networkIds, llm.NetworkId) + } + mountedModelInfo, _ := llm.FetchMountedModelInfo() + res[idx].MountedModels = mountedModelInfo + res[idx].NetworkType = llm.NetworkType + res[idx].NetworkId = llm.NetworkId + } + + // fetch volume + volumeQ := GetVolumeManager().Query().In("llm_Id", ids) + volumes := []SVolume{} + db.FetchModelObjects(GetVolumeManager(), volumeQ, &volumes) + for _, volume := range volumes { + for i, id := range ids { + if id == volume.LLMId { + res[i].Volume = api.Volume{ + Id: volume.Id, + Name: volume.Name, + TemplateId: volume.TemplateId, + StorageType: volume.StorageType, + SizeMB: volume.SizeMB, + } + } + } + } + + // fetch sku + skus := make(map[string]SLLMSku) + err := db.FetchModelObjectsByIds(GetLLMSkuManager(), "id", skuIds, &skus) + if err == nil { + for i := range llms { + if sku, ok := skus[llms[i].LLMSkuId]; ok { + res[i].LLMSku = sku.Name + res[i].VcpuCount = sku.Cpu + res[i].VmemSizeMb = sku.Memory + res[i].Devices = sku.Devices + if llms[i].BandwidthMb != 0 { + res[i].EffectBandwidthMbps = llms[i].BandwidthMb + } else { + res[i].EffectBandwidthMbps = sku.Bandwidth + } + } + } + } else { + log.Errorf("FetchModelObjectsByIds LLMSkuManager fail %s", err) + } + + // fetch image + images := make(map[string]SLLMImage) + err = db.FetchModelObjectsByIds(GetLLMImageManager(), "id", imgIds, &images) + if err == nil { + for i := range llms { + if image, ok := images[llms[i].LLMImageId]; ok { + res[i].LLMImage = image.Name + res[i].LLMImageLable = image.ImageLabel + res[i].LLMImageName = image.ImageName + } + } + } else { + log.Errorf("FetchModelObjectsByIds GetLLMImageManager fail %s", err) + } + + // fetch network + if len(networkIds) > 0 { + networks, err := fetchNetworks(ctx, userCred, networkIds) + if err == nil { + for i, llm := range llms { + if net, ok := networks[llm.NetworkId]; ok { + res[i].Network = net.Name + } + } + } else { + log.Errorf("fail to retrieve network info %s", err) + } + } + + // fetch host + if len(serverIds) > 0 { + // allow query cmp server + serverMap := make(map[string]computeapi.ServerDetails) + s := auth.GetAdminSession(ctx, options.Options.Region) + params := computeoptions.ServerListOptions{} + limit := 1000 + params.Limit = &limit + details := true + params.Details = &details + params.Scope = "maxallowed" + offset := 0 + for offset < len(serverIds) { + lastIdx := offset + limit + if lastIdx > 
len(serverIds) { + lastIdx = len(serverIds) + } + params.Id = serverIds[offset:lastIdx] + results, err := compute.Servers.List(s, jsonutils.Marshal(params)) + if err != nil { + log.Errorf("query servers fails %s", err) + break + } else { + offset = lastIdx + for i := range results.Data { + guest := computeapi.ServerDetails{} + err := results.Data[i].Unmarshal(&guest) + if err == nil { + serverMap[guest.Id] = guest + } + } + } + } + + for i := range llms { + llmStatus := api.LLM_STATUS_UNKNOWN + llm := llms[i] + if guest, ok := serverMap[llm.SvrId]; ok { + // find guest + if len(guest.Containers) == 0 { + llmStatus = api.LLM_LLM_STATUS_NO_CONTAINER + } else { + llmCtr := guest.Containers[0] + if llmCtr == nil { + llmStatus = api.LLM_LLM_STATUS_NO_CONTAINER + } else { + llmStatus = llmCtr.Status + } + } + + res[i].Server = guest.Name + res[i].StartTime = guest.LastStartAt + res[i].Host = guest.Host + res[i].HostId = guest.HostId + res[i].HostAccessIp = guest.HostAccessIp + res[i].HostEIP = guest.HostEIP + res[i].Zone = guest.Zone + res[i].ZoneId = guest.ZoneId + + adbMappedPort := -1 + // for j := range res[i].AccessInfo { + // res[i].AccessInfo[j].DesktopIp = guest.IPs + // res[i].AccessInfo[j].ServerIp = guest.HostAccessIp + // res[i].AccessInfo[j].PublicIp = guest.HostEIP + // /*if res[i].AccessInfo[j].ListenPort == api.DESKTOP_ADB_PORT { + // adbMappedPort = res[i].AccessInfo[j].AccessPort + // }*/ + // } + + if adbMappedPort >= 0 { + res[i].AdbAccess = fmt.Sprintf("%s:%d", guest.HostAccessIp, adbMappedPort) + if len(res[i].HostEIP) > 0 { + res[i].AdbPublic = fmt.Sprintf("%s:%d", guest.HostEIP, adbMappedPort) + } + } + } else { + llmStatus = api.LLM_LLM_STATUS_NO_SERVER + } + res[i].LLMStatus = llmStatus + } + } + + return res +} + func (lm *SLLMManager) OnCreateComplete(ctx context.Context, items []db.IModel, userCred mcclient.TokenCredential, ownerId mcclient.IIdentityProvider, query jsonutils.JSONObject, data []jsonutils.JSONObject) { parentTaskId, _ := data[0].GetString("parent_task_id") err := runBatchCreateTask(ctx, items, userCred, data, "LLMBatchCreateTask", parentTaskId) @@ -360,3 +546,35 @@ func (llm *SLLM) StartSyncStatusTask(ctx context.Context, userCred mcclient.Toke func (llm *SLLM) GetLLMUrl(ctx context.Context, userCred mcclient.TokenCredential) (string, error) { return llm.GetLLMContainerDriver().GetLLMUrl(ctx, userCred, llm) } + +func (llm *SLLM) GetDetailsUrl(ctx context.Context, userCred mcclient.TokenCredential, query jsonutils.JSONObject) (jsonutils.JSONObject, error) { + accessUrl, err := llm.GetLLMUrl(ctx, userCred) + if err != nil { + return nil, errors.Wrap(err, "GetLLMUrl") + } + output := jsonutils.NewDict() + output.Set("access_url", jsonutils.NewString(accessUrl)) + return output, nil +} + +func fetchNetworks(ctx context.Context, userCred mcclient.TokenCredential, networkIds []string) (map[string]computeapi.NetworkDetails, error) { + s := auth.GetSession(ctx, userCred, "") + params := computeoptions.ServerListOptions{} + params.Id = networkIds + limit := len(networkIds) + params.Limit = &limit + params.Scope = "maxallowed" + results, err := compute.Networks.List(s, jsonutils.Marshal(params)) + if err != nil { + return nil, errors.Wrap(err, "Networks.List") + } + networks := make(map[string]computeapi.NetworkDetails) + for i := range results.Data { + net := computeapi.NetworkDetails{} + err := results.Data[i].Unmarshal(&net) + if err == nil { + networks[net.Id] = net + } + } + return networks, nil +} diff --git a/pkg/llm/models/llm_base.go 
b/pkg/llm/models/llm_base.go index 3d8d1ecadd..807f653d23 100644 --- a/pkg/llm/models/llm_base.go +++ b/pkg/llm/models/llm_base.go @@ -63,6 +63,9 @@ type SLLMBase struct { DebugMode bool `default:"false" nullable:"false" list:"user" update:"user"` RootfsUnlimit bool `default:"false" nullable:"false" list:"user" update:"user"` + + NetworkType string `charset:"utf8" list:"user" update:"user" create:"optional"` + NetworkId string `charset:"utf8" nullable:"true" list:"user" update:"user" create:"optional"` } func (man *SLLMBaseManager) ValidateCreateData(ctx context.Context, userCred mcclient.TokenCredential, ownerId mcclient.IIdentityProvider, query jsonutils.JSONObject, input api.LLMBaseCreateInput) (api.LLMBaseCreateInput, error) { @@ -94,6 +97,20 @@ func (man *SLLMBaseManager) ValidateCreateData(ctx context.Context, userCred mcc input.PreferHost = hostDetails.Id } + if len(input.NetworkType) > 0 && !api.IsLLMSkuBaseNetworkType(input.NetworkType) { + return input, errors.Wrapf(httperrors.ErrInputParameter, "invalid network type %s", input.NetworkType) + } + + if len(input.NetworkId) > 0 { + s := auth.GetSession(ctx, userCred, "") + netObj, err := compute.Networks.Get(s, input.NetworkId, nil) + if err != nil { + return input, errors.Wrapf(httperrors.ErrInputParameter, "invalid network_id %s", input.NetworkId) + } + input.NetworkId, _ = netObj.GetString("id") + input.NetworkType, _ = netObj.GetString("server_type") + } + return input, nil } @@ -136,6 +153,22 @@ func (man *SLLMBaseManager) ListItemFilter(ctx context.Context, q *sqlchemy.SQue return q, errors.Wrap(err, "SEnabledResourceBaseManager.ListItemFilter") } + if len(input.NetworkType) > 0 { + q = q.Equals("network_type", input.NetworkType) + } + if len(input.NetworkId) > 0 { + s := auth.GetSession(ctx, userCred, "") + netObj, err := compute.Networks.Get(s, input.NetworkId, nil) + if err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, errors.Wrapf(httperrors.ErrResourceNotFound, "network %s not found", input.NetworkId) + } + return nil, errors.Wrap(err, "Networks.Get") + } + netId, _ := netObj.GetString("id") + q = q.Equals("network_id", netId) + } + if len(input.Host) > 0 { serverIds, err := GetServerIdsByHost(ctx, userCred, input.Host) if err != nil { @@ -225,14 +258,6 @@ func (man *SLLMBaseManager) ListItemFilter(ctx context.Context, q *sqlchemy.SQue q = q.In("svr_id", serverIds) } } - // if input.Unused != nil { - // instanceQ := GetDesktopInstanceManager().Query().SubQuery() - // if *input.Unused { - // q = q.NotEquals("id", instanceQ.Query(instanceQ.Field("desktop_id")).SubQuery()) - // } else { - // q = q.Join(instanceQ, sqlchemy.Equals(q.Field("id"), instanceQ.Field("desktop_id"))) - // } - // } return q, nil } diff --git a/pkg/llm/models/llm_base_pod.go b/pkg/llm/models/llm_base_pod.go index 2f6fb69f78..219a986b78 100644 --- a/pkg/llm/models/llm_base_pod.go +++ b/pkg/llm/models/llm_base_pod.go @@ -5,6 +5,7 @@ import ( "fmt" "yunion.io/x/jsonutils" + "yunion.io/x/pkg/util/seclib" "yunion.io/x/onecloud/pkg/apis" computeapi "yunion.io/x/onecloud/pkg/apis/compute" @@ -44,7 +45,7 @@ func GetLLMBasePodCreateInput( data.VcpuCount = skuBase.Cpu data.VmemSize = skuBase.Memory + 1 - data.Name = input.Name + data.Name = input.Name + "-" + seclib.RandomPassword(6) // disks data.Disks = make([]*computeapi.DiskConfig, 0) @@ -110,18 +111,18 @@ func GetLLMBasePodCreateInput( } bandwidth := llmBase.BandwidthMb if bandwidth == 0 { - bandwidth = skuBase.BandwidthMb + bandwidth = skuBase.Bandwidth } network := 
&computeapi.NetworkConfig{ BwLimit: bandwidth, - NetType: computeapi.TNetworkType(skuBase.NetworkType), + NetType: computeapi.TNetworkType(llmBase.NetworkType), } - if skuBase.NetworkType == string(computeapi.NETWORK_TYPE_HOSTLOCAL) { + if llmBase.NetworkType == string(computeapi.NETWORK_TYPE_HOSTLOCAL) { network.PortMappings = portMappings } - if len(skuBase.NetworkId) > 0 { - network.Network = skuBase.NetworkId + if len(llmBase.NetworkId) > 0 { + network.Network = llmBase.NetworkId } data.Networks = []*computeapi.NetworkConfig{ diff --git a/pkg/llm/models/llm_instant_model_sync.go b/pkg/llm/models/llm_instant_model_sync.go index f4b7ad014a..effb6e6636 100644 --- a/pkg/llm/models/llm_instant_model_sync.go +++ b/pkg/llm/models/llm_instant_model_sync.go @@ -250,12 +250,13 @@ func (llm *SLLM) PerformQuickModels(ctx context.Context, userCred mcclient.Token errs = append(errs, errors.Wrap(err, "FetchByIdOrName")) } } else { - instApp := instModelObj.(*SInstantModel) - input.Models[i].Id = instApp.Id - input.Models[i].ModelId = instApp.ModelId - input.Models[i].Tag = instApp.ModelTag + instMdl := instModelObj.(*SInstantModel) + input.Models[i].Id = instMdl.Id + input.Models[i].ModelId = instMdl.ModelId + input.Models[i].Tag = instMdl.ModelTag + input.Models[i].LlmType = instMdl.LlmType if input.Method == apis.QuickModelInstall { - toInstallSizeGb += float64(instApp.GetActualSizeMb()) * 1024 * 1024 / 1000 / 1000 / 1000 + toInstallSizeGb += float64(instMdl.GetActualSizeMb()) * 1024 * 1024 / 1000 / 1000 / 1000 } } } else { @@ -269,8 +270,12 @@ func (llm *SLLM) PerformQuickModels(ctx context.Context, userCred mcclient.Token input.Models[i].Id = mdl.Id input.Models[i].Tag = mdl.ModelTag input.Models[i].ModelId = mdl.ModelId + input.Models[i].LlmType = mdl.LlmType } } + if !apis.IsLLMContainerType(input.Models[i].LlmType) || apis.LLMContainerType(input.Models[i].LlmType) != llm.GetLLMContainerDriver().GetType() { + errs = append(errs, errors.Wrapf(httperrors.ErrInvalidStatus, "model %s is not of type %s", input.Models[i].ModelId, llm.GetLLMContainerDriver().GetType())) + } } if len(errs) > 0 { return nil, errors.NewAggregate(errs) @@ -353,6 +358,22 @@ func (llm *SLLM) FetchMountedModelFullName() ([]string, error) { return llm.FetchModelsFullName(nil, &boolTrue) } +func (llm *SLLM) FetchMountedModelInfo() ([]apis.MountedModelInfo, error) { + boolTrue := true + models, err := llm.FetchModels(nil, &boolTrue, nil) + if err != nil { + return nil, errors.Wrap(err, "FetchModels") + } + result := make([]apis.MountedModelInfo, len(models)) + for idx, mdl := range models { + result[idx] = apis.MountedModelInfo{ + FullName: mdl.ModelName + ":" + mdl.Tag, + Id: mdl.ModelId, + } + } + return result, nil +} + func (llm *SLLM) RequestUnmountModel(ctx context.Context, userCred mcclient.TokenCredential, input apis.LLMSyncModelTaskInput) ([]string, []*commonapi.ContainerVolumeMountDiskPostOverlay, error) { if input.LLMStatus == apis.LLM_STATUS_RUNNING { err := llm.RefreshInstantModels(ctx, userCred, true) @@ -594,7 +615,7 @@ type mdlFullNameInfo struct { IsMounted bool } -func (llm *SLLM) UpdateMountedModelFullNames(ctx context.Context, mdlinfos []string, isReset bool, imageId string, skuId string) error { +func (llm *SLLM) UpdateMountedModelFullNames(ctx context.Context, userCred mcclient.TokenCredential, mdlinfos []string, isReset bool, imageId string, skuId string) error { mdlFullNameInfos := make(map[string]*mdlFullNameInfo) for i := range mdlinfos { parts := strings.Split(mdlinfos[i], "@") @@ -620,15 +641,19 @@ 
func (llm *SLLM) UpdateMountedModelFullNames(ctx context.Context, mdlinfos []str } } for i := range sku.MountedModels { - parts := strings.Split(sku.MountedModels[i], "@") - if !isReset && slices.Contains(deletedModelIds, parts[0]) { - // if not reset, and the package is deleted, skip it + instMdl, err := GetInstantModelManager().FetchByIdOrName(ctx, userCred, sku.MountedModels[i]) + if err != nil { + return errors.Wrap(err, "FetchByIdOrName") + } + instantModle := instMdl.(*SInstantModel) + if !isReset && slices.Contains(deletedModelIds, instantModle.ModelId) { + // if not reset, and the model is deleted, skip it continue } - if _, ok := mdlFullNameInfos[parts[0]]; !ok { - mdlFullNameInfos[parts[0]] = &mdlFullNameInfo{ - ModelId: parts[0], - ModelFullName: parts[1], + if _, ok := mdlFullNameInfos[instantModle.ModelId]; !ok { + mdlFullNameInfos[instantModle.ModelId] = &mdlFullNameInfo{ + ModelId: instantModle.ModelId, + ModelFullName: instantModle.ModelName + ":" + instantModle.ModelTag, IsMounted: false, } } diff --git a/pkg/llm/models/llm_pod.go b/pkg/llm/models/llm_pod.go index 1c6fea6e64..1cba439584 100644 --- a/pkg/llm/models/llm_pod.go +++ b/pkg/llm/models/llm_pod.go @@ -26,7 +26,7 @@ func GetLLMPodCreateInput( // generate post overlay info { - err = llm.UpdateMountedModelFullNames(ctx, nil, true, input.LLMImageId, input.LLMSkuId) + err = llm.UpdateMountedModelFullNames(ctx, userCred, nil, true, input.LLMImageId, input.LLMSkuId) if err != nil { return nil, errors.Wrap(err, "UpdateMountedModelFullNames") } diff --git a/pkg/llm/models/llm_save_instant_model.go b/pkg/llm/models/llm_save_instant_model.go index ed457bdf9e..dfde2a9746 100644 --- a/pkg/llm/models/llm_save_instant_model.go +++ b/pkg/llm/models/llm_save_instant_model.go @@ -56,8 +56,8 @@ func (llm *SLLM) PerformSaveInstantModel( return nil, errors.Wrap(err, "detectModelPaths") } - if len(input.ImageName) == 0 { - input.ImageName = fmt.Sprintf("%s-%s", mdlInfo.Name+":"+mdlInfo.Tag, time.Now().Format("060102")) + if len(input.ModelFullName) == 0 { + input.ModelFullName = fmt.Sprintf("%s-%s", mdlInfo.Name+":"+mdlInfo.Tag, time.Now().Format("060102")) } var ownerId mcclient.IIdentityProvider @@ -89,17 +89,25 @@ func (llm *SLLM) PerformSaveInstantModel( input.ProjectId = ownerId.GetProjectId() input.ProjectDomainId = ownerId.GetProjectDomainId() + modelName, modelTag, _ := llm.GetLargeLanguageModelName(input.ModelFullName) + if len(modelName) == 0 { + modelName = mdlInfo.Name + } + if len(modelTag) == 0 { + modelTag = mdlInfo.Tag + } + drv := llm.GetLLMContainerDriver() instantModelCreateInput := api.InstantModelCreateInput{ LlmType: drv.GetType(), ModelId: mdlInfo.ModelId, - ModelName: mdlInfo.Name, - ModelTag: mdlInfo.Tag, + ModelName: modelName, + ModelTag: modelTag, Mounts: mountDirs, } - instantModelCreateInput.Name = input.ImageName - booTrue := true - instantModelCreateInput.DoNotImport = &booTrue + instantModelCreateInput.Name = input.ModelFullName + boolTrue := true + instantModelCreateInput.DoNotImport = &boolTrue log.Debugf("instantModelCreateInput: %s", jsonutils.Marshal(instantModelCreateInput)) instantMdlObj, err := db.DoCreate(GetInstantModelManager(), ctx, userCred, nil, jsonutils.Marshal(instantModelCreateInput), ownerId) @@ -135,7 +143,7 @@ func (llm *SLLM) DoSaveModelImage(ctx context.Context, userCred mcclient.TokenCr } saveImageInput := computeapi.ContainerSaveVolumeMountToImageInput{ - GenerateName: input.ImageName, + GenerateName: input.ModelFullName, Notes: fmt.Sprintf("instance model image for %s(%s)", 
input.ModelId, instantModel.ModelName+":"+instantModel.ModelTag), Index: 0, Dirs: saveDirs, diff --git a/pkg/llm/models/llm_sku.go b/pkg/llm/models/llm_sku.go index 818cc8be3b..85f750b988 100644 --- a/pkg/llm/models/llm_sku.go +++ b/pkg/llm/models/llm_sku.go @@ -9,11 +9,15 @@ import ( "yunion.io/x/pkg/errors" "yunion.io/x/sqlchemy" + imageapi "yunion.io/x/onecloud/pkg/apis/image" api "yunion.io/x/onecloud/pkg/apis/llm" "yunion.io/x/onecloud/pkg/cloudcommon/db" "yunion.io/x/onecloud/pkg/cloudcommon/validators" "yunion.io/x/onecloud/pkg/httperrors" "yunion.io/x/onecloud/pkg/mcclient" + "yunion.io/x/onecloud/pkg/mcclient/auth" + imagemodules "yunion.io/x/onecloud/pkg/mcclient/modules/image" + mcclientoptions "yunion.io/x/onecloud/pkg/mcclient/options" "yunion.io/x/onecloud/pkg/util/stringutils2" ) @@ -82,38 +86,38 @@ func (manager *SLLMSkuManager) FetchCustomizeColumns( fields stringutils2.SSortedStrings, isList bool, ) []api.LLMSkuDetails { - // skuIds := []string{} + skuIds := []string{} imageIds := []string{} - // templateIds := []string{} + templateIds := []string{} skus := []SLLMSku{} jsonutils.Update(&skus, objs) virows := manager.SSharableVirtualResourceBaseManager.FetchCustomizeColumns(ctx, userCred, query, objs, fields, isList) for _, sku := range skus { - // skuIds = append(skuIds, sku.Id) + skuIds = append(skuIds, sku.Id) imageIds = append(imageIds, sku.LLMImageId) - // if sku.Volumes != nil && len(*sku.Volumes) > 0 && len((*sku.Volumes)[0].TemplateId) > 0 { - // templateIds = append(templateIds, (*sku.Volumes)[0].TemplateId) - // } + if sku.Volumes != nil && len(*sku.Volumes) > 0 && len((*sku.Volumes)[0].TemplateId) > 0 { + templateIds = append(templateIds, (*sku.Volumes)[0].TemplateId) + } } - // q := GetLLMManager().Query().In("llm_model_id", skuIds).GroupBy("llm_model_id") - // q = q.AppendField(q.Field("llm_model_id")) - // q = q.AppendField(sqlchemy.COUNT("llm_capacity")) - // details := []struct { - // LLMModelId string - // LLMCapacity int - // }{} - // q.All(&details) + q := GetLLMManager().Query().In("llm_sku_id", skuIds).GroupBy("llm_sku_id") + q = q.AppendField(q.Field("llm_sku_id")) + q = q.AppendField(sqlchemy.COUNT("llm_capacity")) + details := []struct { + LLMSkuId string + LLMCapacity int + }{} + q.All(&details) res := make([]api.LLMSkuDetails, len(objs)) - for i := range skus { + for i, sku := range skus { res[i].SharableVirtualResourceDetails = virows[i] - // for _, v := range details { - // if v.LLMModelId == sku.Id { - // res[i].LLMCapacity = v.LLMCapacity - // break - // } - // } + for _, v := range details { + if v.LLMSkuId == sku.Id { + res[i].LLMCapacity = v.LLMCapacity + break + } + } } { images := make(map[string]SLLMImage) @@ -127,22 +131,22 @@ func (manager *SLLMSkuManager) FetchCustomizeColumns( } } } else { - log.Errorf("FetchModelObjectsByIds DesktopImageManager fail %s", err) + log.Errorf("FetchModelObjectsByIds LLMImageManager fail %s", err) } } - // if len(templateIds) > 0 { - // templates, err := fetchTemplates(ctx, userCred, templateIds) - // if err == nil { - // for i, sku := range skus { - // if templ, ok := templates[(*sku.Volumes)[0].TemplateId]; ok { - // res[i].Template = templ.Name - // } - // } - // } else { - // log.Errorf("fail to retrive image info %s", err) - // } - // } + if len(templateIds) > 0 { + templates, err := fetchTemplates(ctx, userCred, templateIds) + if err == nil { + for i, sku := range skus { + if templ, ok := templates[(*sku.Volumes)[0].TemplateId]; ok { + res[i].Template = templ.Name + } + } + } else { + 
log.Errorf("fail to retrive image info %s", err) + } + } return res } @@ -177,6 +181,20 @@ func (sku *SLLMSku) ValidateUpdateData(ctx context.Context, userCred mcclient.To return input, errors.Wrap(err, "validate LLMSkuBaseUpdateInput") } + if input.MountedModels != nil { + for i, mdl := range input.MountedModels { + instMdl, err := GetInstantModelManager().FetchByIdOrName(ctx, userCred, mdl) + if err != nil { + return input, errors.Wrapf(err, "validate mounted model %s", mdl) + } + instantModle := instMdl.(*SInstantModel) + if instantModle.LlmType != sku.LLMType { + return input, errors.Wrapf(httperrors.ErrInvalidStatus, "mounted model %s is not of type %s", mdl, sku.LLMType) + } + input.MountedModels[i] = instantModle.GetId() + } + } + if input.LLMImageId != "" { imgObj, err := validators.ValidateModel(ctx, userCred, GetLLMImageManager(), &input.LLMImageId) if err != nil { @@ -198,3 +216,25 @@ func (sku *SLLMSku) ValidateDeleteCondition(ctx context.Context, info jsonutils. } return nil } + +func fetchTemplates(ctx context.Context, userCred mcclient.TokenCredential, templateIds []string) (map[string]imageapi.ImageDetails, error) { + s := auth.GetSession(ctx, userCred, "") + params := mcclientoptions.BaseListOptions{} + params.Id = templateIds + limit := len(templateIds) + params.Limit = &limit + params.Scope = "maxallowed" + results, err := imagemodules.Images.List(s, jsonutils.Marshal(params)) + if err != nil { + return nil, errors.Wrap(err, "Images.List") + } + templates := make(map[string]imageapi.ImageDetails) + for i := range results.Data { + tmpl := imageapi.ImageDetails{} + err := results.Data[i].Unmarshal(&tmpl) + if err == nil { + templates[tmpl.Id] = tmpl + } + } + return templates, nil +} diff --git a/pkg/llm/models/mcp_agent.go b/pkg/llm/models/mcp_agent.go index cbc7c927af..414ff89162 100644 --- a/pkg/llm/models/mcp_agent.go +++ b/pkg/llm/models/mcp_agent.go @@ -53,6 +53,9 @@ type SMCPAgentManager struct { type SMCPAgent struct { db.SSharableVirtualResourceBase + // LLMId 关联的 LLM 实例 ID + LLMId string `width:"128" charset:"ascii" nullable:"true" list:"user" create:"optional" update:"user"` + // LLMUrl 对应后端大模型的 base 请求地址 LLMUrl string `width:"512" charset:"utf8" nullable:"false" list:"user" create:"required" update:"user"` // LLMDriver 对应使用的大模型驱动(llm_client),现在可以被设置为 ollama 或 openai @@ -90,6 +93,7 @@ func (man *SMCPAgentManager) ValidateCreateData(ctx context.Context, userCred mc return input, errors.Wrapf(err, "fetch LLM by id %s", input.LLMId) } llm := llmObj.(*SLLM) + input.LLMId = llm.Id llmUrl, err := llm.GetLLMUrl(ctx, userCred) if err != nil { return input, errors.Wrapf(err, "get LLM URL from LLM %s", input.LLMId) @@ -100,7 +104,9 @@ func (man *SMCPAgentManager) ValidateCreateData(ctx context.Context, userCred mc if err != nil { return input, errors.Wrapf(err, "get LLM Sku from LLM %s", input.LLMId) } - input.Model = sku.LLMModelName + if len(input.Model) == 0 { + input.Model = sku.LLMModelName + } } // 验证 llm_url 不为空 @@ -202,14 +208,29 @@ func (manager *SMCPAgentManager) FetchCustomizeColumns( agents := []SMCPAgent{} jsonutils.Update(&agents, objs) + llmIds := make([]string, 0) + for i := range agents { + if len(agents[i].LLMId) > 0 { + llmIds = append(llmIds, agents[i].LLMId) + } + } + + var llmIdNameMap map[string]string + if len(llmIds) > 0 { + var err error + llmIdNameMap, err = db.FetchIdNameMap2(GetLLMManager(), llmIds) + if err != nil { + log.Errorf("FetchIdNameMap2 for LLMs failed: %v", err) + } + } + for i := range rows { rows[i].SharableVirtualResourceDetails 
= vrows[i]
 		if i < len(agents) {
-			rows[i].LLMUrl = agents[i].LLMUrl
-			rows[i].LLMDriver = agents[i].LLMDriver
-			rows[i].Model = agents[i].Model
-			rows[i].ApiKey = agents[i].ApiKey
-			rows[i].McpServer = agents[i].McpServer
+			rows[i].LLMId = agents[i].LLMId
+			if name, ok := llmIdNameMap[agents[i].LLMId]; ok {
+				rows[i].LLMName = name
+			}
 		}
 	}
 
@@ -276,13 +297,8 @@ func (mcp *SMCPAgent) GetDetailsToolRequest(
 func (mcp *SMCPAgent) GetDetailsChatStream(
 	ctx context.Context,
 	userCred mcclient.TokenCredential,
-	input api.LLMChatTestInput,
+	input api.LLMMCPAgentRequestInput,
 ) (jsonutils.JSONObject, error) {
-	llmClient := mcp.GetLLMClientDriver()
-	if llmClient == nil {
-		return nil, errors.Error("failed to get LLM client driver")
-	}
-
 	appParams := appsrv.AppContextGetParams(ctx)
 	if appParams == nil {
 		return nil, errors.Error("failed to get app params")
@@ -292,51 +308,38 @@ func (mcp *SMCPAgent) GetDetailsChatStream(
 	w.Header().Set("Content-Type", "text/event-stream")
 	w.Header().Set("Cache-Control", "no-cache")
 	w.Header().Set("Connection", "keep-alive")
-	w.Header().Set("X-Accel-Buffering", "no")
 
 	if f, ok := w.(http.Flusher); ok {
 		f.Flush()
+	} else {
+		return nil, errors.Error("Streaming unsupported!")
 	}
 
-	message := llmClient.NewUserMessage(input.Message)
-
-	err := llmClient.ChatStream(ctx, mcp, []ILLMChatMessage{message}, nil, func(chunk ILLMChatResponse) error {
-		content := chunk.GetContent()
+	_, err := mcp.process(ctx, userCred, &input, func(content string) error {
 		if len(content) > 0 {
-			fmt.Fprintf(w, "%s", content)
+			for line := range strings.SplitSeq(content, "\n") {
+				fmt.Fprintf(w, "data: %s\n", line)
+			}
+			fmt.Fprintf(w, "\n")
 			if f, ok := w.(http.Flusher); ok {
 				f.Flush()
 			}
 		}
 		return nil
 	})
+
 	if err != nil {
-		fmt.Fprintf(w, "\nError: %v\n", err)
+		fmt.Fprintf(w, "data: Error: %v\n\n", err)
 	}
 
 	return nil, nil
 }
 
-func (mcp *SMCPAgent) GetDetailsRequest(
-	ctx context.Context,
-	userCred mcclient.TokenCredential,
-	input api.LLMMCPAgentRequestInput,
-) (jsonutils.JSONObject, error) {
-	// 调用 ProcessMCPAgentRequest
-	answer, err := mcp.process(ctx, userCred, &input)
-	if err != nil {
-		return nil, errors.Wrap(err, "process MCP agent request")
-	}
-
-	// 返回结果
-	result := map[string]interface{}{
-		"answer": answer.Answer,
-	}
-	return jsonutils.Marshal(result), nil
-}
-
 // process handles a user request
-func (mcp *SMCPAgent) process(ctx context.Context, userCred mcclient.TokenCredential, req *api.LLMMCPAgentRequestInput) (*api.MCPAgentResponse, error) {
+// The flow is forced into two phases:
+// Phase 1: use a non-streaming Chat call to obtain the tool-call arguments and execute the tools
+// Phase 2: use ChatStream to stream the final response
+func (mcp *SMCPAgent) process(ctx context.Context, userCred mcclient.TokenCredential, req *api.LLMMCPAgentRequestInput, onStream func(string) error) (*api.MCPAgentResponse, error) {
 	// Fetch the tool list from the MCP server
 	mcpClient := utils.NewMCPClient(mcp.McpServer, 10*time.Minute, userCred)
 	defer mcpClient.Close()
@@ -357,77 +360,110 @@ func (mcp *SMCPAgent) process(ctx context.Context, userCred mcclient.TokenCreden
 	// Build the system prompt
 	systemPrompt := buildSystemPrompt()
 
-	// 初始化消息历史,使用接口类型
+	// Initialize the message history
 	messages := []ILLMChatMessage{
 		llmClient.NewSystemMessage(systemPrompt),
-		llmClient.NewUserMessage(req.Query),
+		llmClient.NewUserMessage(req.Message),
 	}
 
 	// Record tool calls
 	var toolCallRecords []api.MCPAgentToolCallRecord
 
-	// Agent 循环
-	for i := 0; i < api.MCPAgentMaxIterations; i++ {
-		log.Infof("Agent iteration %d", i+1)
-
-		// 调用 LLM 客户端,传入接口类型
-		resp, err := llmClient.Chat(ctx, mcp, messages, tools)
-		if err != nil {
-			return nil, errors.Wrap(err, "chat with LLM client")
-		}
-
-		// 检查是否有工具调用
-		if !resp.HasToolCalls() {
-			// 没有工具调用,返回最终答案
-			return &api.MCPAgentResponse{
-				Success:   true,
-				Answer:    resp.GetContent(),
-				ToolCalls: toolCallRecords,
-			}, nil
-		}
-
-		// 处理工具调用
-		toolCalls := resp.GetToolCalls()
-		log.Infof("Got %d tool calls from LLM", len(toolCalls))
-
-		// 添加助手消息(带工具调用),使用接口类型
-		messages = append(messages, llmClient.NewAssistantMessageWithToolCalls(toolCalls))
-
-		// 执行每个工具调用
-		for _, tc := range toolCalls {
-			fc := tc.GetFunction()
-			toolName := fc.GetName()
-			arguments := fc.GetArguments()
-
-			// 确保 arguments 不为 nil
-			if arguments == nil {
-				arguments = make(map[string]interface{})
-			}
-
-			log.Infof("Calling tool: %s with arguments: %v", toolName, arguments)
-
-			// 调用 MCP 工具
-			result, err := mcpClient.CallTool(ctx, toolName, arguments)
-			resultText := utils.FormatToolResult(toolName, result, err)
-			log.Infoln("Get result from mcp query", resultText)
-
-			// 记录工具调用
-			toolCallRecords = append(toolCallRecords, api.MCPAgentToolCallRecord{
-				ToolName:  toolName,
-				Arguments: arguments,
-				Result:    resultText,
-			})
-
-			// 添加工具结果消息,使用接口类型
-			messages = append(messages, llmClient.NewToolMessage(tc.GetId(), toolName, resultText))
-		}
+	log.Infof("Phase 1: Thinking & Acting...")
+	resp, err := llmClient.Chat(ctx, mcp, messages, tools)
+	if err != nil {
+		return nil, errors.Wrap(err, "phase 1 chat error")
+	}
+
+	// Check whether the response contains tool calls
+	if !resp.HasToolCalls() {
+		// Phase 1 made no tool calls; emulate streaming the answer back
+		content := resp.GetContent()
+		if onStream != nil && len(content) > 0 {
+			// Emulate streaming output: push the content in small chunks
+			chunkSize := 10 // push 10 bytes per chunk
+			for i := 0; i < len(content); i += chunkSize {
+				end := i + chunkSize
+				if end > len(content) {
+					end = len(content)
+				}
+				chunk := content[i:end]
+				if err := onStream(chunk); err != nil {
+					return nil, errors.Wrap(err, "stream content error")
+				}
+				// Small delay to mimic real streaming output
+				time.Sleep(10 * time.Millisecond)
+			}
+		}
+		return &api.MCPAgentResponse{
+			Success:   true,
+			Answer:    content,
+			ToolCalls: toolCallRecords,
+		}, nil
+	}
+
+	// Handle the tool calls
+	toolCalls := resp.GetToolCalls()
+	log.Infof("Got %d tool calls from Phase 1", len(toolCalls))
+
+	// Append the assistant message that decided to call tools to the history
+	messages = append(messages, llmClient.NewAssistantMessageWithToolCalls(toolCalls))
+
+	// Execute each tool call
+	for _, tc := range toolCalls {
+		fc := tc.GetFunction()
+		toolName := fc.GetName()
+		arguments := fc.GetArguments()
+
+		if arguments == nil {
+			arguments = make(map[string]interface{})
+		}
+
+		log.Infof("Calling tool: %s with arguments: %v", toolName, arguments)
+
+		// Call the MCP tool
+		result, err := mcpClient.CallTool(ctx, toolName, arguments)
+		resultText := utils.FormatToolResult(toolName, result, err)
+		log.Infoln("Get result from mcp query", resultText)
+
+		// Record the tool call
+		toolCallRecords = append(toolCallRecords, api.MCPAgentToolCallRecord{
+			ToolName:  toolName,
+			Arguments: arguments,
+			Result:    resultText,
+		})
+
+		// Append the tool execution result to the history
+		messages = append(messages, llmClient.NewToolMessage(tc.GetId(), toolName, resultText))
+	}
+
+	log.Infof("Phase 2: Streaming Response...")
+
+	var finalAnswer strings.Builder
+
+	err = llmClient.ChatStream(ctx, mcp, messages, tools, func(chunk ILLMChatResponse) error {
+		content := chunk.GetContent()
+		if len(content) > 0 {
+			// Aggregate the final answer
+			finalAnswer.WriteString(content)
+
+			// Stream the content out in real time
+			if onStream != nil {
+				if err := onStream(content); err != nil {
+					return err
				}
+			}
+		}
+		return nil
+	})
+
+	if err != nil {
+		return nil, errors.Wrap(err, "phase 2 stream error")
 	}
 
-	// 达到最大迭代次数
 	return &api.MCPAgentResponse{
-		Success: false,
-		Answer:  "处理请求时达到最大迭代次数,请尝试简化您的问题。",
-		Error:   "max iterations reached",
+		Success:   true,
+		Answer:    finalAnswer.String(),
 		ToolCalls: toolCallRecords,
 	}, nil
 }
diff --git a/pkg/llm/models/sku.go b/pkg/llm/models/sku.go
index 46a6348a05..a7da291d88 100644
--- a/pkg/llm/models/sku.go
+++ b/pkg/llm/models/sku.go
@@ -12,8 +12,6 @@ import (
 	"yunion.io/x/onecloud/pkg/cloudcommon/db"
 	"yunion.io/x/onecloud/pkg/httperrors"
 	"yunion.io/x/onecloud/pkg/mcclient"
-	"yunion.io/x/onecloud/pkg/mcclient/auth"
-	compute "yunion.io/x/onecloud/pkg/mcclient/modules/compute"
 )
 
 func NewSLLMSkuBaseManager(dt interface{}, tableName string, keyword string, keywordPlural string) SLLMSkuBaseManager {
@@ -34,7 +32,7 @@ type SLLMSkuBaseManager struct {
 type SLLMSkuBase struct {
 	db.SSharableVirtualResourceBase
 
-	BandwidthMb int `nullable:"false" default:"0" create:"optional" list:"user" update:"user"`
+	Bandwidth int `nullable:"false" default:"0" create:"optional" list:"user" update:"user"`
 	Cpu int `nullable:"false" default:"1" create:"optional" list:"user" update:"user"`
 	Memory int `nullable:"false" default:"512" create:"optional" list:"user" update:"user"`
 	Volumes *api.Volumes `charset:"utf8" length:"medium" nullable:"true" list:"user" update:"user" create:"optional"`
@@ -43,9 +41,6 @@ type SLLMSkuBase struct {
 	Envs *api.Envs `charset:"utf8" nullable:"true" list:"user" update:"user" create:"optional"`
 	// Properties
 	Properties map[string]string `charset:"utf8" nullable:"true" list:"user" update:"user" create:"optional"`
-
-	NetworkType string `charset:"utf8" list:"user" update:"user" create:"optional"`
-	NetworkId string `charset:"utf8" nullable:"true" list:"user" update:"user" create:"optional"`
 }
 
 func (man *SLLMSkuBaseManager) ListItemFilter(
@@ -78,20 +73,6 @@ func (man *SLLMSkuBaseManager) ValidateCreateData(ctx context.Context, userCred
 		return input, errors.Wrap(httperrors.ErrInputParameter, "volumes cannot be empty")
 	}
 
-	if !api.IsLLMSkuBaseNetworkType(input.NetworkType) {
-		return input, errors.Wrapf(httperrors.ErrInputParameter, "invalid network type %s", input.NetworkType)
-	}
-
-	if len(input.NetworkId) > 0 {
-		s := auth.GetSession(ctx, userCred, "")
-		netObj, err := compute.Networks.Get(s, input.NetworkId, nil)
-		if err != nil {
-			return input, errors.Wrapf(httperrors.ErrInputParameter, "invalid network_id %s", input.NetworkId)
-		}
-		input.NetworkId, _ = netObj.GetString("id")
-		input.NetworkType, _ = netObj.GetString("server_type")
-	}
-
 	input.Status = api.STATUS_READY
 
 	return input, nil
@@ -130,21 +111,5 @@ func (skuBase *SLLMSkuBase) ValidateUpdateData(ctx context.Context, userCred mcc
 	}
 	input.Volumes = (*api.Volumes)(&volumes)
 
-	if input.NetworkType != nil && !api.IsLLMSkuBaseNetworkType(*input.NetworkType) {
-		return input, errors.Wrapf(httperrors.ErrInputParameter, "invalid network type %s", *input.NetworkType)
-	}
-
-	if input.NetworkId != nil && len(*input.NetworkId) > 0 {
-		s := auth.GetSession(ctx, userCred, "")
-		netObj, err := compute.Networks.Get(s, *input.NetworkId, nil)
-		if err != nil {
-			return input, errors.Wrapf(httperrors.ErrInputParameter, "invalid network_id %s", *input.NetworkId)
-		}
-		netId, _ := netObj.GetString("id")
-		netType, _ := netObj.GetString("server_type")
-		input.NetworkId = &netId
-		input.NetworkType = &netType
-	}
-
 	return input, nil
 }
diff --git a/pkg/llm/service/handler.go b/pkg/llm/service/handler.go
index 782f8e682a..e38249c697 100644
--- a/pkg/llm/service/handler.go
+++ b/pkg/llm/service/handler.go
@@ -1,6 +1,9 @@
 package service
 
 import (
+	"context"
+	"net/http"
+
 	"yunion.io/x/onecloud/pkg/appsrv"
 	"yunion.io/x/onecloud/pkg/appsrv/dispatcher"
"yunion.io/x/onecloud/pkg/cloudcommon/db" @@ -8,12 +11,20 @@ import ( "yunion.io/x/onecloud/pkg/llm/models" ) +func handleOllamaRegistryYAML(ctx context.Context, w http.ResponseWriter, r *http.Request) { + yamlContent := models.GetInstantModelManager().GetOllamaRegistryYAML() + w.Header().Set("Content-Type", "application/x-yaml; charset=utf-8") + appsrv.Send(w, yamlContent) +} + func InitHandlers(app *appsrv.Application, isSlave bool) { db.InitAllManagers() db.RegistUserCredCacheUpdater() taskman.AddTaskHandler("", app, isSlave) + app.AddHandler("GET", "/ollama-registry.yaml", handleOllamaRegistryYAML) + for _, manager := range []db.IModelManager{ taskman.TaskManager, taskman.SubTaskManager, diff --git a/pkg/llm/tasks/llm/llm_start_save_model_image_task.go b/pkg/llm/tasks/llm/llm_start_save_model_image_task.go index 3e82d29f63..ce9cdef387 100644 --- a/pkg/llm/tasks/llm/llm_start_save_model_image_task.go +++ b/pkg/llm/tasks/llm/llm_start_save_model_image_task.go @@ -93,14 +93,11 @@ func (task *LLMStartSaveModelImageTask) OnSaveModelImageComplete(ctx context.Con task.SetStageComplete(ctx, nil) - // if input.AutoRestart { - // llm.StartRestartTask(ctx, task.UserCred, api.DesktopRestartTaskInput{ - // DesktopId: llm.Id, - // DesktopStatus: api.LLM_STATUS_READY, - // }, "") - // } else { - // llm.SetStatus(ctx, task.UserCred, api.LLM_STATUS_READY, "OnSaveModelImageComplete") - // } + if input.AutoRestart { + llm.StartStartTask(ctx, task.UserCred, "") + } else { + llm.SetStatus(ctx, task.UserCred, api.LLM_STATUS_READY, "OnSaveModelImageComplete") + } } func (task *LLMStartSaveModelImageTask) OnSaveModelImageCompleteFailed(ctx context.Context, obj db.IStandaloneModel, err jsonutils.JSONObject) { diff --git a/pkg/mcclient/options/llm/image.go b/pkg/mcclient/options/llm/image.go index b49b31a6df..214effc043 100644 --- a/pkg/mcclient/options/llm/image.go +++ b/pkg/mcclient/options/llm/image.go @@ -17,6 +17,8 @@ func (o *LLMImageShowOptions) Params() (jsonutils.JSONObject, error) { type LLMImageListOptions struct { options.BaseListOptions + + LLMType string `json:"llm_type" choices:"ollama|dify" help:"filter by llm type"` } func (o *LLMImageListOptions) Params() (jsonutils.JSONObject, error) { @@ -25,9 +27,10 @@ func (o *LLMImageListOptions) Params() (jsonutils.JSONObject, error) { type LLMImageCreateOptions struct { apis.SharableVirtualResourceCreateInput - IMAGE_NAME string - IMAGE_LABEL string - CredentialId string + IMAGE_NAME string `json:"image_name"` + IMAGE_LABEL string `json:"image_label"` + CredentialId string `json:"credential_id"` + LLM_TYPE string `json:"llm_type" choices:"ollama|dify" help:"llm type: ollama or dify"` } func (o *LLMImageCreateOptions) Params() (jsonutils.JSONObject, error) { @@ -35,12 +38,13 @@ func (o *LLMImageCreateOptions) Params() (jsonutils.JSONObject, error) { } type LLMImageUpdateOptions struct { - apis.SharableVirtualResourceCreateInput + apis.SharableVirtualResourceBaseUpdateInput ID string - IMAGE_NAME string - IMAGE_LABEL string - CredentialId string + ImageName string `json:"image_name"` + ImageLabel string `json:"image_label"` + CredentialId string `json:"credential_id"` + LlmType string `json:"llm_type" choices:"ollama|dify" help:"llm type: ollama or dify"` } func (o *LLMImageUpdateOptions) GetId() string { diff --git a/pkg/mcclient/options/llm/llm.go b/pkg/mcclient/options/llm/llm.go index f177d75f72..d3c0f35b10 100644 --- a/pkg/mcclient/options/llm/llm.go +++ b/pkg/mcclient/options/llm/llm.go @@ -1,10 +1,7 @@ package llm import ( - "strings" - 
"yunion.io/x/jsonutils" - "yunion.io/x/pkg/util/regutils" api "yunion.io/x/onecloud/pkg/apis/llm" "yunion.io/x/onecloud/pkg/mcclient/options" @@ -15,6 +12,9 @@ type LLMBaseListOptions struct { Host string `help:"filter by host"` LLMStatus []string `help:"filter by server status"` + NetworkType string `help:"filter by network type"` + NetworkId string `help:"filter by network id"` + ListenPort int `help:"filter by listen port"` PublicIp string `help:"filter by public ip"` VolumeId string `help:"filter by volume id"` @@ -55,6 +55,9 @@ type LLMBaseCreateOptions struct { ProjectId string PreferHost string + NETWORK_TYPE string `json:"network_type" choices:"guest|hostlocal"` + NetworkId string `help:"id of network" json:"network_id"` + BandwidthMb int Count int `default:"1" help:"batch create count" json:"-"` @@ -120,13 +123,13 @@ type LLMSaveInstantModelOptions struct { MODEL_ID string `help:"llm model id, e.g. 500a1f067a9f"` Name string `help:"instant app name, e.g. qwen3:8b"` - // AutoRestart bool + AutoRestart bool } func (opts *LLMSaveInstantModelOptions) Params() (jsonutils.JSONObject, error) { input := api.LLMSaveInstantModelInput{ - ModelId: opts.MODEL_ID, - ImageName: opts.Name, + ModelId: opts.MODEL_ID, + ModelFullName: opts.Name, // AutoRestart: opts.AutoRestart, } return jsonutils.Marshal(input), nil @@ -135,47 +138,18 @@ func (opts *LLMSaveInstantModelOptions) Params() (jsonutils.JSONObject, error) { type LLMQuickModelsOptions struct { LLMIdOptions - MODEL []string `help:"model id and optional display name in the format of modelId[@modelName:modelTag], e.g. 6f48b936a09f or 6f48b936a09f@qwen2:0.5b"` + MODEL []string `help:"model id of instant model, e.g. qwen3:0.6b-251202 or 7f72b5a1-4049-43db-8e91-8dee736ae1ac"` Method string `help:"install or uninstall" choices:"install|uninstall"` } func (opts *LLMQuickModelsOptions) Params() (jsonutils.JSONObject, error) { params := api.LLMPerformQuickModelsInput{} - for _, mdlFul := range opts.MODEL { - var mdl api.ModelInfo - - var idPart string - var nameAndTagPart string - - if idx := strings.Index(mdlFul, "@"); idx >= 0 { - idPart = mdlFul[:idx] - nameAndTagPart = mdlFul[idx+1:] - - if idxTag := strings.LastIndex(nameAndTagPart, ":"); idxTag >= 0 { - mdl.DisplayName = nameAndTagPart[:idxTag] - mdl.Tag = nameAndTagPart[idxTag+1:] - } else { - mdl.DisplayName = nameAndTagPart - } - } else { - idPart = mdlFul - - if idxTag := strings.LastIndex(idPart, ":"); idxTag >= 0 { - mdl.Tag = idPart[idxTag+1:] - idPart = idPart[:idxTag] - } - } - - if regutils.MatchUUID(idPart) { - mdl.Id = idPart - } else { - mdl.ModelId = idPart - } - - params.Models = append(params.Models, mdl) + for _, mdl := range opts.MODEL { + params.Models = append(params.Models, api.ModelInfo{ + Id: mdl, + }) } - if len(opts.Method) > 0 { params.Method = api.TQuickModelMethod(opts.Method) } diff --git a/pkg/mcclient/options/llm/llm_sku_base.go b/pkg/mcclient/options/llm/llm_sku_base.go index a3f46f6754..ef57e3f1e2 100644 --- a/pkg/mcclient/options/llm/llm_sku_base.go +++ b/pkg/mcclient/options/llm/llm_sku_base.go @@ -19,10 +19,8 @@ type LLMSkuBaseCreateOptions struct { MEMORY int `help:"memory size MB"` DISK_SIZE int `help:"disk size MB"` - NETWORK_TYPE string `json:"network_type" choices:"guest|hostlocal"` - NetworkId string `help:"id of network" json:"network_id"` - Bandwidth int - StorageType string + Bandwidth int + StorageType string // DiskOverlay string `help:"disk overlay, e.g. 
/opt/steam-data/base:/opt/steam-data/games"` TemplateId string PortMappings []string `help:"port mapping in the format of protocol:port[:prefix][:first_port_offset][:env_key=env_value], e.g. tcp:5555:192.168.0.0/16:5:WOLF_BASE_PORT=20000"` @@ -62,9 +60,7 @@ type LLMSkuBaseUpdateOptions struct { DiskSize *int `help:"disk size MB"` StorageType string TemplateId string - NoTemplate bool `json:"-" help:"remove template"` - NetworkType string `json:"network_type" choices:"guest|hostlocal"` - NetworkId string `help:"id of network" json:"network_id"` + NoTemplate bool `json:"-" help:"remove template"` Bandwidth *int // Dpi *int // Fps *int diff --git a/pkg/mcclient/options/llm/mcp_agent.go b/pkg/mcclient/options/llm/mcp_agent.go index b9fb95bbc1..dd19704151 100644 --- a/pkg/mcclient/options/llm/mcp_agent.go +++ b/pkg/mcclient/options/llm/mcp_agent.go @@ -140,28 +140,15 @@ func (opts *MCPAgentToolRequestOptions) Params() (jsonutils.JSONObject, error) { return jsonutils.Marshal(input), nil } -type MCPAgentChatTestOptions struct { - MCPAgentIdOptions - - Message string `help:"test message to send to LLM" json:"message"` -} - -func (opts *MCPAgentChatTestOptions) Params() (jsonutils.JSONObject, error) { - input := api.LLMChatTestInput{ - Message: opts.Message, - } - return jsonutils.Marshal(input), nil -} - type MCPAgentMCPAgentRequestOptions struct { MCPAgentIdOptions - Query string `help:"query to send to MCP agent" json:"query"` + Message string `help:"message to send to MCP agent" json:"message"` } func (opts *MCPAgentMCPAgentRequestOptions) Params() (jsonutils.JSONObject, error) { input := api.LLMMCPAgentRequestInput{ - Query: opts.Query, + Message: opts.Message, } return jsonutils.Marshal(input), nil } diff --git a/scripts/sync_dify_images.sh b/scripts/sync_dify_images.sh index 3d6b7e25d4..3db518fb28 100644 --- a/scripts/sync_dify_images.sh +++ b/scripts/sync_dify_images.sh @@ -51,7 +51,7 @@ for image in "${IMAGES[@]}"; do echo " Target: ${DST}" echo - skopeo copy "${SRC}" "${DST}" + skopeo copy --override-os linux --override-arch amd64 "${SRC}" "${DST}" echo "Completed: ${short_name}:${tag}" done
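
Note: the /ollama-registry.yaml route registered in pkg/llm/service/handler.go serves the instant-model registry as YAML with Content-Type application/x-yaml. A minimal client-side sketch for fetching it is below; the service address is an assumption for illustration only and is not part of this change.

    package main

    import (
    	"fmt"
    	"io"
    	"net/http"
    )

    func main() {
    	// Assumed llm service address; substitute the real endpoint.
    	base := "http://127.0.0.1:8080"
    	resp, err := http.Get(base + "/ollama-registry.yaml")
    	if err != nil {
    		panic(err)
    	}
    	defer resp.Body.Close()
    	body, err := io.ReadAll(resp.Body)
    	if err != nil {
    		panic(err)
    	}
    	// Prints the ollama registry configuration in YAML.
    	fmt.Println(string(body))
    }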