diff --git a/cmd/mcp-server/main.go b/cmd/mcp-server/main.go new file mode 100644 index 0000000000..41bb465b8f --- /dev/null +++ b/cmd/mcp-server/main.go @@ -0,0 +1,23 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "yunion.io/x/onecloud/pkg/mcp-server/service" +) + +func main() { + service.StartService() +} diff --git a/docs/mcp-server/README.md b/docs/mcp-server/README.md new file mode 100644 index 0000000000..fea2dd6679 --- /dev/null +++ b/docs/mcp-server/README.md @@ -0,0 +1,59 @@ +# MCP Server + +MCP Server 是 Cloudpods 多云管理平台的核心组件之一,负责处理多云资源的统一管理和调度。 + +## 目录结构 + +``` +├── adapters/ # 适配器模块,用于对接不同云平台的API +├── config/ # 配置模块,处理服务配置和加载 +├── models/ # 数据模型,定义云资源的数据结构 +├── registry/ # 注册中心,管理可用的工具和服务 +├── server/ # 服务核心,包含服务启动和初始化逻辑 +└── tools/ # 工具模块,实现各种云资源管理功能 +``` + +## 架构设计 + +MCP Server 采用模块化设计,主要包括以下几个核心模块: + +1. **适配器模块 (Adapters)**: 负责与不同云平台的API进行交互,实现资源的统一管理。 +2. **配置模块 (Config)**: 处理服务的配置加载和管理,支持多种配置方式。 +3. **数据模型 (Models)**: 定义云资源的数据结构,为其他模块提供统一的数据访问接口。 +4. **注册中心 (Registry)**: 管理可用的工具和服务,支持动态注册和发现。 +5. **服务核心 (Server)**: 负责服务的启动、初始化和生命周期管理。 +6. **工具模块 (Tools)**: 实现各种云资源管理功能,如VPC、网络、镜像等。 + +## 运行机制 + +1. 服务启动时,首先加载配置文件并初始化各个模块。 +2. 适配器模块根据配置连接到相应的云平台。 +3. 注册中心注册所有可用的工具和服务。 +4. 服务核心启动HTTP服务器,监听客户端请求。 +5. 客户端通过API调用相应的工具来管理云资源。 + +## 主要功能 + +- 统一管理多云资源(VPC、网络、镜像、主机等) +- 支持多种云平台(AWS、Azure、阿里云等) +- 提供RESTful API接口 +- 支持资源的查询、创建、更新和删除操作 + +## 配置说明 + +配置文件位于 `options/options.go`,主要包含以下配置项: + +- ServerConfig: 服务配置,如监听地址、端口等 +- MCPConfig: MCP相关配置 +- ExternalConfig: 外部服务配置 + +## 开发指南 + +1. 实现新的云资源管理功能时,需要在 `tools/` 目录下创建相应的工具文件。 +2. 工具需要实现 `Tool` 接口,包括 `GetTool`、`Handle` 和 `GetName` 方法。 +3. 数据模型定义在 `models/` 目录下,需要根据云平台API文档进行定义。 +4. 适配器实现在 `adapters/` 目录下,用于与云平台API进行交互。 + +## 贡献 + +欢迎提交Issue和Pull Request来改进MCP Server。 \ No newline at end of file diff --git a/docs/mcp-server/images/mcp-server-tools.png b/docs/mcp-server/images/mcp-server-tools.png new file mode 100644 index 0000000000..6ddbaf1e2b Binary files /dev/null and b/docs/mcp-server/images/mcp-server-tools.png differ diff --git a/docs/mcp-server/使用文档.md b/docs/mcp-server/使用文档.md new file mode 100644 index 0000000000..4877cbbf0e --- /dev/null +++ b/docs/mcp-server/使用文档.md @@ -0,0 +1,732 @@ +# mcp-server使用文档 + +- 使用市面上的主流mcp-server客户端,如cline、cursor等均可 +- 这里以cline为例 + + + +## 配置mcp-server + +可以选择使用stdio模式或者是sse模式,源码中默认使用sse模式 + +- sse模式: + +````json +{ + "mcpServers": { + "cloudpods-sse": { + "url": "http://localhost:12001/sse" + } + } +} +```` +- stdio模式: +````json +{ + "mcpServers": { + "mcp-server": { + "disabled": false, + "timeout": 60, + "type": "stdio", + "command": "D:/ospp/cloudpods/cmd/mcp-server/mcp-server.exe", + "args": [ + "config", + "D:/ospp/cloudpods/pkg/mcp-server/config/mcp-server.yaml" + ] + } + } +} +```` + + + +![mcp-server-tools](./images/mcp-server-tools.png) + +## 通过大模型和mcp-server交互 + +1. 登录cloudpods前端界面 +2. 获取accessKey和secretKey + +这里通过命令行和climc工具创建ak和sk + +````bash +climc credential-create-aksk +```` + +3. 在和AI对话时提供accessKey和secretKey + +## 使用用例示例 + +目前mcp-server总共提供了15个功能: + +资源查询: + +- cloudpods_list_images: 查询镜像列表 +- cloudpods_list_networks: 查询网络列表 +- cloudpods_list_regions:查询区域列表 +- cloudpods_list_servers:查询虚拟机实例列表 +- cloudpods_list_serverskus:查询服务器规格列表 +- cloudpods_list_storages:查询存储列表 +- cloudpods_list_vpcs:查询 VPC 列表 + +资源操作: + +- cloudpods_create_server:创建虚拟机实例 +- cloudpods_delete_server:删除虚拟机实例 +- cloudpods_start_server:启动虚拟机实例 +- cloudpods_stop_server:停止虚拟机实例 +- cloudpods_reset_server_password:重置虚拟机密码 +- cloudpods_get_server_monitor:获取Cloudpods虚拟机监控信息 +- cloudpods_get_server_stats:获取Cloudpods虚拟机实时统计信息 + +## 测试环境信息 +- **认证信息**: + - Access Key: `73e97540cabe4a7580fc469760df5e80` + - Secret Key: `enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=` +- **Cloudpods 实例**: `https://10.21.76.40` + +--- + +## 1. cloudpods_list_images - 查询镜像列表 + +### 用户提示词示例 + +- "帮我查询云平台的镜像列表,要求显示前10条(从第1条开始),筛选条件为操作系统类型是Linux或Windows,关键词包含'centos'。" +- "列出当前账号下可用的镜像,最多返回5条,跳过前3条(即从第4条开始),只显示类型为Ubuntu的镜像。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "limit": "10", + "offset": "0", + "search": "", + "os_types": "Linux,Windows" +} +``` + +### 返回结果 +```json +{ + "images": [], + "query_info": { + "count": 0, + "limit": 10, + "offset": 0, + "os_types": ["Linux", "Windows"], + "search": "", + "total": 0 + }, + "summary": { + "has_more": false, + "next_offset": 0, + "returned_count": 0, + "total_images": 0 + } +} +``` + +--- + +## 2. cloudpods_list_networks - 查询网络列表 + +### 用户提示词示例 + +- "查询VPC(ID: vpc-default)下的网络列表,显示前5条(从第1条开始),筛选名称包含'生产环境'的网络。" +- "列出当前区域默认VPC中的网络,最多返回8条,跳过前2条,关键词为'测试'。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "limit": "10", + "offset": "0", + "search": "", + "vpc_id": "default" +} +``` + +### 返回结果 +```json +{ + "networks": [], + "query_info": { + "count": 0, + "limit": 10, + "offset": 0, + "search": "", + "total": 0, + "vpc_id": "default" + }, + "summary": { + "has_more": false, + "next_offset": 0, + "returned_count": 0, + "total_networks": 0 + } +} +``` + +--- + +## 3. cloudpods_list_regions - 查询区域列表 + +### 用户提示词示例 + +- "查询云服务商OneCloud支持的所有区域,显示前10条(从第1条开始),筛选名称包含'华北'的区域。" +- "列出提供商为OneCloud的区域列表,最多返回3条,跳过前0条(即从第1条开始),关键词为空。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "limit": "10", + "offset": "0", + "search": "", + "provider": "OneCloud" +} +``` + +### 返回结果 +```json +{ + "cloudregions": [ + { + "can_delete": false, + "can_update": true, + "city": "", + "cloud_env": "", + "country_code": "", + "created_at": "2025-08-27T08:52:17Z", + "description": "Default Region", + "enabled": true, + "environment": "", + "external_id": "", + "guest_count": 0, + "guest_increment_count": 0, + "id": "default", + "imported_at": "2025-08-27T08:52:17Z", + "is_emulated": false, + "latitude": 0, + "longitude": 0, + "metadata": null, + "name": "Default", + "network_count": 0, + "progress": 100, + "provider": "OneCloud", + "source": "local", + "status": "inservice", + "updated_at": "2025-08-27T08:52:17Z", + "vpc_count": 1, + "zone_count": 1 + } + ], + "query_info": { + "count": 1, + "limit": 10, + "offset": 0, + "provider": "OneCloud", + "search": "", + "total": 1 + }, + "summary": { + "has_more": false, + "next_offset": 1, + "returned_count": 1, + "total_cloudregions": 1 + } +} +``` + +--- + +## 4. cloudpods_list_servers - 查询虚拟机实例列表 + +### 用户提示词示例 + +- "查询当前账号下状态为'运行中'的虚拟机,显示前5条(从第1条开始),筛选名称包含'web-server'的实例。" +- "列出状态为'stopped'的虚拟机实例,最多返回8条,跳过前3条(即从第4条开始),关键词为'test'。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "limit": "5", + "offset": "0", + "search": "", + "status": "" +} +``` + +### 返回结果 +```json +{ + "query_info": { + "count": 0, + "limit": 5, + "offset": 0, + "search": "", + "status": "", + "total": 0 + }, + "servers": [], + "summary": { + "returned_count": 0, + "total_servers": 0 + } +} +``` + +--- + +## 5. cloudpods_list_serverskus - 查询服务器规格列表 + +### 用户提示词示例 + +- "查询默认区域(default)下,CPU核心数为2或4,内存大小为4096MB或8192MB的x86架构服务器规格,显示前10条(从第1条开始)。" +- "列出云区域(ID: cn-north-1)中,CPU架构为ARM,核心数8,内存16384MB的服务器规格,最多返回5条,跳过前0条。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "limit": "10", + "offset": "0", + "search": "", + "cloudregion_ids": "default", + "zone_ids": "", + "cpu_core_count": "1,2,4,8", + "memory_size_mb": "1024,2048,4096,8192", + "providers": "OneCloud", + "cpu_arch": "x86" +} +``` + +### 返回结果 +```json +{ + "query_info": { + "cloudregion_ids": ["default"], + "count": 9, + "cpu_arch": ["x86"], + "cpu_core_count": ["1", "2", "4", "8"], + "limit": 10, + "memory_size_mb": ["1024", "2048", "4096", "8192"], + "offset": 0, + "providers": ["OneCloud"], + "search": "", + "total": 9, + "zone_ids": null + }, + "serverskus": [ + { + "attached_disk_count": 0, + "attached_disk_size_gb": 0, + "attached_disk_type": "", + "can_delete": true, + "can_update": true, + "cloud_env": "", + "cloudregion": "Default", + "cloudregion_id": "default", + "cpu_arch": "", + "cpu_core_count": 8, + "created_at": "2025-08-27T08:52:17Z", + "data_disk_max_count": 0, + "data_disk_types": "", + "description": "", + "enabled": true, + "external_id": "", + "gpu_attachable": true, + "gpu_count": "", + "gpu_max_count": 0, + "gpu_spec": "", + "id": "a30758a9-457c-4fe1-8939-b50a4df31ebf", + "imported_at": "2025-08-27T08:52:17Z", + "instance_type_category": "general_purpose", + "instance_type_family": "g1", + "is_emulated": false, + "local_category": "general_purpose", + "md5": "", + "memory_size_mb": 8192, + "metadata": null, + "name": "ecs.g1.c8m8", + "nic_max_count": 1, + "nic_type": "", + "os_name": "Any", + "postpaid_status": "available", + "prepaid_status": "available", + "progress": 100, + "provider": "OneCloud", + "region": "Default", + "region_ext_id": "", + "region_external_id": "", + "region_id": "default", + "source": "local", + "status": "init", + "sys_disk_max_size_gb": 0, + "sys_disk_min_size_gb": 0, + "sys_disk_resizable": true, + "sys_disk_type": "", + "total_guest_count": 0, + "update_version": 0, + "updated_at": "2025-08-27T08:52:17Z", + "zone": "", + "zone_ext_id": "", + "zone_id": "" + } + // ... 更多服务器规格数据(共9条记录) + ], + "summary": { + "has_more": false, + "next_offset": 9, + "returned_count": 9, + "total_serverskus": 9 + } +} +``` + +--- + +## 6. cloudpods_list_storages - 查询存储列表 + +### 用户提示词示例 + +- "查询默认区域(default)下类型为'local'的存储资源,显示前10条(从第1条开始),筛选名称包含'system'的存储。" +- "列出云区域(ID: cn-north-1)中,提供商为OneCloud,类型为'block'的存储,最多返回5条,跳过前2条(即从第3条开始)。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "limit": "10", + "offset": "0", + "search": "", + "cloudregion_ids": "default", + "zone_ids": "", + "providers": "OneCloud", + "storage_types": "local", + "host_id": "" +} +``` + +### 返回结果 +```json +{ + "query_info": { + "cloudregion_ids": ["default"], + "count": 0, + "host_id": "", + "limit": 10, + "offset": 0, + "providers": ["OneCloud"], + "search": "", + "storage_types": ["local"], + "total": 0, + "zone_ids": null + }, + "storages": [], + "summary": { + "has_more": false, + "next_offset": 0, + "returned_count": 0, + "total_storages": 0 + } +} +``` + +--- + +## 7. cloudpods_list_vpcs - 查询 VPC 列表 + +### 用户提示词示例 + +- "查询默认区域(default)下的VPC列表,显示前10条(从第1条开始),筛选名称包含'生产'的VPC。" +- "列出云区域(ID: cn-north-2)中的VPC,最多返回3条,跳过前0条(即从第1条开始),关键词为空。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "limit": "10", + "offset": "0", + "search": "", + "cloudregion_id": "default" +} +``` + +### 返回结果 +```json +{ + "query_info": { + "cloudregion_id": "default", + "count": 1, + "limit": 10, + "offset": 0, + "search": "", + "total": 1 + }, + "summary": { + "has_more": false, + "next_offset": 1, + "returned_count": 1, + "total_vpcs": 1 + }, + "vpcs": [ + { + "accept_vpc_peer_count": 0, + "account": "", + "account_health_status": "", + "account_id": "", + "account_status": "", + "brand": "OneCloud", + "can_delete": false, + "can_update": true, + "cidr_block": "", + "cidr_block6": "", + "cloud_env": "onpremise", + "cloudregion": "Default", + "cloudregion_id": "default", + "created_at": "2025-08-27T08:52:17Z", + "description": "Default VPC", + "direct": false, + "dns_zone_count": 0, + "domain_id": "default", + "domain_src": "", + "enabled": false, + "environment": "", + "external_access_mode": "eip-distgw", + "external_id": "", + "globalvpc": "", + "globalvpc_id": "", + "id": "default", + "imported_at": "2025-08-27T08:52:17Z", + "is_default": true, + "is_emulated": false, + "is_public": true, + "manager": "", + "manager_domain": "", + "manager_domain_id": "", + "manager_id": "", + "manager_project": "", + "manager_project_id": "", + "metadata": null, + "name": "Default", + "natgateway_count": 0, + "network_count": 0, + "progress": 100, + "project_domain": "Default", + "provider": "OneCloud", + "public_scope": "system", + "public_src": "", + "region": "Default", + "region_ext_id": "", + "region_external_id": "", + "region_id": "default", + "request_vpc_peer_count": 0, + "routetable_count": 0, + "shared_domains": null, + "shared_projects": null, + "source": "local", + "status": "available", + "updated_at": "2025-08-27T08:52:27Z", + "wire_count": 1 + } + ] +} +``` + +--- + +## 8. cloudpods_create_server - 创建虚拟机实例 + +### 用户提示词示例 + +- "创建一个名为'web-server-01'的虚拟机,配置2核CPU、4GB内存,使用镜像ID'img-centos7',网络ID'net-prod',自动启动,密码设置为'SecurePass123!',备注为'生产环境Web服务器'。" +- "创建一台虚拟机实例,名称为'db-server-01',CPU核心数4,内存8GB,使用镜像ID'img-ubuntu20',网络ID'net-db',不自动启动,密码'Admin@2025',项目ID'proj-123'。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "name": "test-vm", + "vcpu_count": "2", + "vmem_size": "4096", + "image_id": "example-image-id", + "network_id": "example-network-id", + "count": "1", + "auto_start": "true", + "password": "TestPassword123", + "billing_type": "postpaid", + "duration": "", + "description": "Test VM created via MCP", + "hostname": "test-vm", + "hypervisor": "kvm", + "user_data": "", + "keypair_id": "", + "project_id": "", + "zone_id": "", + "region_id": "default", + "disable_delete": "false", + "boot_order": "cdn", + "metadata": "{\"environment\": \"test\", \"owner\": \"admin\"}", + "data_disks": "[{\"size\": 50, \"disk_type\": \"data\"}]", + "secgroup_id": "", + "secgroups": "", + "serversku_id": "" +} +``` + +--- + +## 9. cloudpods_delete_server - 删除虚拟机实例 + +### 用户提示词示例 + +- "删除ID为'vm-001'的虚拟机实例,不删除关联的磁盘和弹性IP。" +- "彻底删除ID为'vm-002'的虚拟机实例(包括所有关联磁盘、快照和弹性IP)。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "server_id": "example-server-id", + "delete_disks": "false", + "delete_eip": "false", + "delete_snapshots": "false", + "override_pending_delete": "false", + "purge": "false" +} +``` + +--- + +## 10. cloudpods_start_server - 启动虚拟机实例 + +### 用户提示词示例 + +- "启动ID为'vm-003'的虚拟机实例,使用默认QEMU版本。" +- "强制启动ID为'vm-004'的虚拟机实例(忽略预检查警告)。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "server_id": "example-server-id", + "auto_prepaid": "false", + "qemu_version": "" +} +``` + +## 11. cloudpods_stop_server - 停止虚拟机实例 + +### 用户提示词示例 + +- + "停止ID为'vm-005'的虚拟机实例(正常关机,不强制),停止计费。" +- "强制停止ID为'vm-006'的虚拟机实例(立即断电),超时时间设置为60秒。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "server_id": "example-server-id", + "is_force": "false", + "stop_charging": "false", + "timeout_secs": "30" +} +``` + +--- + +## 12. cloudpods_restart_server - 重启虚拟机实例 + +### 用户提示词示例 + +- "重启ID为'vm-007'的虚拟机实例(正常重启,不强制)。 +- "强制重启ID为'vm-008'的虚拟机实例(立即中断进程)。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "server_id": "example-server-id", + "is_force": "false" +} +``` + +--- + +## 13. cloudpods_reset_server_password - 重置虚拟机密码 + +### 用户提示词示例 + +- "重置ID为'vm-009'的虚拟机密码为'NewSecurePass456!',并自动启动实例。" +- "重置ID为'vm-010'的虚拟机(用户名为'admin')的密码为'Admin@2025New',不自动启动。" + +### MCP实际接收参数 +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "server_id": "example-server-id", + "password": "TestPassword123!", + "reset_password": "true", + "auto_start": "true", + "username": "" +} +``` + +## 14.cloudpods_get_server_monitor - 获取Cloudpods虚拟机监控信息 + +获取Cloudpods虚拟机监控信息,包括CPU、内存、磁盘、网络等指标 + +### 用户提示词示例 + +- + "获取ID为'vm-011'的虚拟机在2025-08-27 00:00到2025-08-27 23:59期间的CPU使用率、内存使用率和网络流入/流出流量监控数据。" +- "查询ID为'vm-012'的虚拟机最近1小时的CPU使用率(每5分钟采样一次)和磁盘读写速率。" + +### MCP实际接收参数 + +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "server_id": "example-server-id", + "metrics": "cpu_usage,mem_usage,disk_usage,net_bps_rx,net_bps_tx", + "start_time": "1724760000", + "end_time": "1724763600" +} +``` + +## 15.cloudpods_get_server_stats - 获取Cloudpods虚拟机实时统计信息 + +获取Cloudpods虚拟机实时统计信息,包括CPU使用率、内存使用率、磁盘使用率和网络流量 + +### 用户提示词示例 + +- "获取ID为'vm-013'的虚拟机当前的CPU使用率、内存使用率、磁盘总空间/已用空间和网络上下行流量统计信息。" +- "查询ID为'vm-014'的虚拟机实时监控数据,包括CPU负载、内存空闲量、磁盘IO和网络连接数。" + +### MCP实际接收参数 + +```json +{ + "ak": "73e97540cabe4a7580fc469760df5e80", + "sk": "enh2ZXRYZUhWdHpjeHJENVdoQmQyUGNnbXhFS2dneUQ=", + "server_id": "example-server-id" +} +``` + diff --git a/docs/mcp-server/安装文档.md b/docs/mcp-server/安装文档.md new file mode 100644 index 0000000000..adaa649a00 --- /dev/null +++ b/docs/mcp-server/安装文档.md @@ -0,0 +1,52 @@ +--- +sidebar_position: 6 +--- + +# MCP Server部署 + +## MCP Server初始化 + +1) 首先配置MCP Server的配置文件 + +```sh +# 编译mcp-server +$ cd /root/cloudpods && make cmd/mcp-server + +# 编写mcp-server服务的配置文件 +$ mkdir -p /etc/yunion/mcp-server +# 编写配置文件,注意根据实际情况修改Cloudpods API的认证信息 +$ cat</etc/yunion/mcp-server/mcp-server.conf +# ==================== 服务器基础配置 ==================== +address = '127.0.0.1' +port = 12001 + + +# ==================== MCP 服务配置 ==================== +mcp_server_name = cloudpods-mcp-server # MCP 服务名称(默认:cloudpods-mcp-server) +mcp_server_version = 1.0.0 # MCP 服务版本(默认:1.0.0) +mcp_server_description = the mcp server of the cloudpods server # MCP 服务描述(默认) + + +# ==================== 外部服务配置 ==================== +identity_base_url = "https:///api/s/identity/v3" # 认证服务入口 +``` + +2) 启动MCP Server服务 + +```sh +# 启动mcp-server服务 +# 默认会从以下路径查找配置文件: /etc/yunion/mcp-server/mcp-server.yaml, ./config/mcp-server.yaml, ./mcp-server.yaml +$ /root/cloudpods/bin/mcp-server --log-level debug + +# 或者使用 --conf 参数指定配置文件路径 +$ /root/cloudpods/bin/mcp-server --log-level debug --conf /etc/yunion/mcp-server/mcp-server.yaml +``` + +## 验证服务 + +MCP Server启动后,可以通过以下方式验证服务是否正常运行: + +```sh +# 检查服务是否监听在指定端口 +$ curl http://localhost:12001/sse +``` \ No newline at end of file diff --git a/go.mod b/go.mod index 211d1f2cde..1e1a77e90b 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/lestrrat-go/jwx v1.0.2 github.com/lestrrat/go-jwx v0.0.0-20180221005942-b7d4802280ae github.com/libvirt/libvirt-go-xml v5.2.0+incompatible + github.com/mark3labs/mcp-go v0.39.1 github.com/mattn/go-sqlite3 v1.14.19 github.com/mdlayher/arp v0.0.0-20190313224443-98a83c8a2717 github.com/mdlayher/ethernet v0.0.0-20190606142754-0394541c37b7 @@ -96,7 +97,7 @@ require ( k8s.io/cri-api v0.22.17 k8s.io/klog/v2 v2.20.0 moul.io/http2curl/v2 v2.3.0 - yunion.io/x/cloudmux v0.3.10-0-alpha.1.0.20250912144144-d0d8cf049d7f + yunion.io/x/cloudmux v0.3.10-0-alpha.1.0.20250915054625-7251d9eeceec yunion.io/x/executor v0.0.0-20250518005516-5402e9e0bed0 yunion.io/x/jsonutils v1.0.1-0.20250507052344-1abcf4f443b1 yunion.io/x/log v1.0.1-0.20240305175729-7cf2d6cd5a91 @@ -146,6 +147,7 @@ require ( github.com/aokoli/goutils v1.0.1 // indirect github.com/apache/thrift v0.13.0 // indirect github.com/aws/aws-sdk-go v1.39.0 // indirect + github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/basgys/goxml2json v1.1.1-0.20181031222924-996d9fc8d313 // indirect github.com/beevik/etree v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect @@ -154,6 +156,7 @@ require ( github.com/boltdb/bolt v1.3.1 // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/cespare/xxhash v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/checkpoint-restore/go-criu/v4 v4.1.0 // indirect @@ -187,7 +190,6 @@ require ( github.com/fatih/color v1.13.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 // indirect - github.com/frankban/quicktest v1.14.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/glycerine/go-unsnap-stream v0.0.0-20181221182339-f9677308dec2 // indirect github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect @@ -218,6 +220,7 @@ require ( github.com/huandu/xstrings v1.2.0 // indirect github.com/huaweicloud/huaweicloud-sdk-go v1.0.26 // indirect github.com/imdario/mergo v0.3.6 // indirect + github.com/invopop/jsonschema v0.13.0 // indirect github.com/jdcloud-api/jdcloud-sdk-go v1.55.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/native v1.1.0 // indirect @@ -232,6 +235,7 @@ require ( github.com/lestrrat/go-pdebug v0.0.0-20180220043741-569c97477ae8 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/ma314smith/signedxml v0.0.0-20210628192057-abc5b481ae1c // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.9 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/mattn/go-runewidth v0.0.13 // indirect @@ -281,6 +285,7 @@ require ( github.com/seccomp/libseccomp-golang v0.9.1 // indirect github.com/smartystreets/assertions v1.2.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect + github.com/spf13/cast v1.7.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/syndtr/gocapability v0.0.0-20180916011248-d98352740cb2 // indirect @@ -297,8 +302,10 @@ require ( github.com/volcengine/volc-sdk-golang v1.0.23 // indirect github.com/willf/bitset v1.1.11-0.20200630133818-d5bec3311243 // indirect github.com/willf/bloom v2.0.3+incompatible // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/xuri/efp v0.0.0-20220603152613-6918739fd470 // indirect github.com/xuri/nfp v0.0.0-20220409054826-5e722a1d9e22 // indirect + github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect go.etcd.io/bbolt v1.3.7 // indirect go.etcd.io/etcd/client/pkg/v3 v3.5.0 // indirect diff --git a/go.sum b/go.sum index dccc963f8b..92fbd40594 100644 --- a/go.sum +++ b/go.sum @@ -169,6 +169,8 @@ github.com/avast/retry-go v3.0.0+incompatible/go.mod h1:XtSnn+n/sHqQIpZ10K1qAevB github.com/aws/aws-sdk-go v1.35.24/go.mod h1:tlPOdRjfxPBpNIwqDj61rmsnA85v9jc0Ps9+muhnW+k= github.com/aws/aws-sdk-go v1.39.0 h1:74BBwkEmiqBbi2CGflEh34l0YNtIibTjZsibGarkNjo= github.com/aws/aws-sdk-go v1.39.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f h1:ZNv7On9kyUzm7fvRZumSyy/IUiSC7AzL0I1jKKtwooA= github.com/baiyubin/aliyun-sts-go-sdk v0.0.0-20180326062324-cfa1a18b161f/go.mod h1:AuiFmCCPBSrqvVMvuqFuk0qogytodnVFVSN5CeJB8Gc= github.com/basgys/goxml2json v1.1.1-0.20181031222924-996d9fc8d313 h1:fKPpQHBQgt4dQuG6x+yH4gdgtodFDgN9rvHzwJzTKeg= @@ -195,6 +197,8 @@ github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8 github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2 h1:1B/+1BcRhOMG1KH/YhNIU8OppSWk5d/NGyfRla88CuY= github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/c-bata/go-prompt v0.2.4 h1:7pKUJ3CUgzdu1HJeWhNRkpVyY/NnlJhM/7d6YgHNOao= github.com/c-bata/go-prompt v0.2.4/go.mod h1:PqlttLXp0E7bZcoDW+dmzyKqFbmQTFoNzGSuW/AQRmo= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -322,8 +326,8 @@ github.com/fernet/fernet-go v0.0.0-20180830025343-9eac43b88a5e h1:P10tZmVD2XclAa github.com/fernet/fernet-go v0.0.0-20180830025343-9eac43b88a5e/go.mod h1:2H9hjfbpSMHwY503FclkV/lZTBh2YlOmLLSda12uL8c= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 h1:BHsljHzVlRcyQhjrss6TZTdY2VfCqZPbv5k3iBFa2ZQ= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= -github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= -github.com/frankban/quicktest v1.14.3/go.mod h1:mgiwOwqx65TmIk1wJ6Q7wvnVMocbUorkibMOrVTHZps= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= @@ -455,7 +459,6 @@ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -530,6 +533,8 @@ github.com/imdario/mergo v0.3.6 h1:xTNEAn+kxVO7dTZGu0CegyqKZmoWFI0rF8UxjlB2d28= github.com/imdario/mergo v0.3.6/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= github.com/influxdata/influxql v1.1.0 h1:sPsaumLFRPMwR5QtD3Up54HXpNND8Eu7G1vQFmi3quQ= github.com/influxdata/influxql v1.1.0/go.mod h1:KpVI7okXjK6PRi3Z5B+mtKZli+R1DnZgb3N+tzevNgo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= github.com/jaypipes/ghw v0.11.0 h1:i0pKvAM7eZk0KvLm9vzpcpDKTRnfR6AQ5pFkPVnYJXU= github.com/jaypipes/ghw v0.11.0/go.mod h1:jeJGbkRB2lL3/gxYzNYzEDETV1ZJ56OKr+CSeSEym+g= github.com/jdcloud-api/jdcloud-sdk-go v1.55.0 h1:mzVj8r6fluEwjn8ogqtGfYW2qSIVUaEq0JAsvjCav3A= @@ -541,6 +546,7 @@ github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHW github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= @@ -576,8 +582,8 @@ github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFB github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.0/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -602,6 +608,10 @@ github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2 github.com/ma314smith/signedxml v0.0.0-20210628192057-abc5b481ae1c h1:UPJygtyk491bJJ/DnRJFuzcq9Dl9NSeFrJ7VdiRzMxc= github.com/ma314smith/signedxml v0.0.0-20210628192057-abc5b481ae1c/go.mod h1:KEgVcb43+f5KFUH/x6Vd3NROG0AIL2CuKMrIqYsmx6E= github.com/mailru/easyjson v0.0.0-20160728113105-d5b7844b561a/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/mark3labs/mcp-go v0.39.1 h1:2oPxk7aDbQhouakkYyKl2T4hKFU1c6FDaubWyGyVE1k= +github.com/mark3labs/mcp-go v0.39.1/go.mod h1:T7tUa2jO6MavG+3P25Oy/jR7iCeJPHImCZHRymCn39g= github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.8/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= @@ -778,7 +788,6 @@ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= -github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= @@ -822,6 +831,8 @@ github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasO github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.2.2/go.mod h1:9ZxEEn6pIJ8Rxe320qSDBk6AsU0r9pR7Q4OcevTdifk= +github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= +github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/pflag v0.0.0-20170130214245-9ff6c6923cff/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= @@ -894,12 +905,16 @@ github.com/willf/bitset v1.1.11-0.20200630133818-d5bec3311243/go.mod h1:RjeCKbqT github.com/willf/bloom v0.0.0-20170505221640-54e3b963ee16/go.mod h1:MmAltL9pDMNTrvUkxdg0k0q5I0suxmuwp3KbyrZLOZ8= github.com/willf/bloom v2.0.3+incompatible h1:QDacWdqcAUI1MPOwIQZRy9kOR7yxfyEmxX8Wdm2/JPA= github.com/willf/bloom v2.0.3+incompatible/go.mod h1:MmAltL9pDMNTrvUkxdg0k0q5I0suxmuwp3KbyrZLOZ8= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/xuri/efp v0.0.0-20220603152613-6918739fd470 h1:6932x8ltq1w4utjmfMPVj09jdMlkY0aiA6+Skbtl3/c= github.com/xuri/efp v0.0.0-20220603152613-6918739fd470/go.mod h1:ybY/Jr0T0GTCnYjKqmdwxyxn2BQf2RcQIIvex5QldPI= github.com/xuri/excelize/v2 v2.7.1 h1:gm8q0UCAyaTt3MEF5wWMjVdmthm2EHAWesGSKS9tdVI= github.com/xuri/excelize/v2 v2.7.1/go.mod h1:qc0+2j4TvAUrBw36ATtcTeC1VCM0fFdAXZOmcF4nTpY= github.com/xuri/nfp v0.0.0-20220409054826-5e722a1d9e22 h1:OAmKAfT06//esDdpi/DZ8Qsdt4+M5+ltca05dA5bG2M= github.com/xuri/nfp v0.0.0-20220409054826-5e722a1d9e22/go.mod h1:WwHg+CVyzlv/TX9xqBFXEZAuxOPxn2k1GNHwG41IIUQ= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -1409,8 +1424,8 @@ sigs.k8s.io/structured-merge-diff/v4 v4.0.1/go.mod h1:bJZC9H9iH24zzfZ/41RGcq60oK sigs.k8s.io/yaml v1.1.0/go.mod h1:UJmg0vDUVViEyp3mgSv9WPwZCDxu4rQW1olrI1uml+o= sigs.k8s.io/yaml v1.2.0 h1:kr/MCeFWJWTwyaHoR9c8EjH9OumOmoF9YGiZd7lFm/Q= sigs.k8s.io/yaml v1.2.0/go.mod h1:yfXDCHCao9+ENCvLSE62v9VSji2MKu5jeNfTrofGhJc= -yunion.io/x/cloudmux v0.3.10-0-alpha.1.0.20250912144144-d0d8cf049d7f h1:E17WmoLx6siAZLLzi1bho2mV006FFXo4q4l2CgwV0mQ= -yunion.io/x/cloudmux v0.3.10-0-alpha.1.0.20250912144144-d0d8cf049d7f/go.mod h1:7P/TJZk8o4JjhFnF1nGZcsPg+sIpMoV0dWPPuG6yGLg= +yunion.io/x/cloudmux v0.3.10-0-alpha.1.0.20250915054625-7251d9eeceec h1:GvDds+zC42TTTFoxui2/Y8mquJQKZ0ay858+/VabUlE= +yunion.io/x/cloudmux v0.3.10-0-alpha.1.0.20250915054625-7251d9eeceec/go.mod h1:7P/TJZk8o4JjhFnF1nGZcsPg+sIpMoV0dWPPuG6yGLg= yunion.io/x/executor v0.0.0-20250518005516-5402e9e0bed0 h1:msG4SiDSVU7CrXH06WuHlNEZXIooTcmNbfrIGHuIHBU= yunion.io/x/executor v0.0.0-20250518005516-5402e9e0bed0/go.mod h1:Uxuou9WQIeJXNpy7t2fPLL0BYLvLiMvGQwY7Qc6aSws= yunion.io/x/jsonutils v0.0.0-20190625054549-a964e1e8a051/go.mod h1:4N0/RVzsYL3kH3WE/H1BjUQdFiWu50JGCFQuuy+Z634= diff --git a/pkg/mcp-server/adapters/cloudpods_adapter.go b/pkg/mcp-server/adapters/cloudpods_adapter.go new file mode 100644 index 0000000000..1a1405eaac --- /dev/null +++ b/pkg/mcp-server/adapters/cloudpods_adapter.go @@ -0,0 +1,78 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapters + +import ( + "context" + + "yunion.io/x/onecloud/pkg/mcclient" + "yunion.io/x/onecloud/pkg/mcp-server/options" +) + +// CloudpodsAdapter 是与 Cloudpods API 交互的适配器,负责认证和资源管理 +type CloudpodsAdapter struct { + client *mcclient.Client + session *mcclient.ClientSession +} + +type CloudRegion struct { + RegionId string `json:"region_id"` +} + +// NewCloudpodsAdapter 创建一个新的 Cloudpods 适配器实例 +func NewCloudpodsAdapter() *CloudpodsAdapter { + + client := mcclient.NewClient( + options.Options.IdentityBaseURL, + options.Options.Timeout, + false, + true, + "", + "", + ) + + return &CloudpodsAdapter{ + client: client, + } +} + +// authenticate 实现 Cloudpods 的认证逻辑,例如获取访问令牌 +func (a *CloudpodsAdapter) authenticate(ak string, sk string) error { + if a.session != nil { + return nil + } + + token, err := a.client.AuthenticateByAccessKey(ak, sk, "") + if err != nil { + return err + } + + a.session = a.client.NewSession( + context.Background(), + "", + "", + "apigateway", + token, + ) + + return nil +} + +func (a *CloudpodsAdapter) getSession(ak string, sk string) (*mcclient.ClientSession, error) { + if err := a.authenticate(ak, sk); err != nil { + return nil, err + } + return a.session, nil +} diff --git a/pkg/mcp-server/adapters/doc.go b/pkg/mcp-server/adapters/doc.go new file mode 100644 index 0000000000..76a365a863 --- /dev/null +++ b/pkg/mcp-server/adapters/doc.go @@ -0,0 +1 @@ +package adapters // import "yunion.io/x/onecloud/pkg/mcp-server/adapters" diff --git a/pkg/mcp-server/adapters/resource_operation_adapter.go b/pkg/mcp-server/adapters/resource_operation_adapter.go new file mode 100644 index 0000000000..bbd580e014 --- /dev/null +++ b/pkg/mcp-server/adapters/resource_operation_adapter.go @@ -0,0 +1,635 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapters + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "yunion.io/x/jsonutils" + + "yunion.io/x/onecloud/pkg/mcclient/modules/compute" + "yunion.io/x/onecloud/pkg/mcclient/modules/monitor" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// StartServer 启动 Cloudpods 中的服务器 +func (a *CloudpodsAdapter) StartServer(ctx context.Context, serverId string, req models.ServerStartRequest, ak string, sk string) (*models.ServerOperationResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造启动参数 + params := jsonutils.NewDict() + + // 如果需要自动续费预付费实例,则设置相应参数 + if req.AutoPrepaid { + params.Set("auto_prepaid", jsonutils.NewBool(true)) + } + + // 如果指定了 QEMU 版本,则设置相应参数 + if req.QemuVersion != "" { + params.Set("qemu_version", jsonutils.NewString(req.QemuVersion)) + } + + // 调用 Cloudpods API 启动服务器 + result, err := compute.Servers.PerformAction(session, serverId, "start", params) + if err != nil { + return nil, fmt.Errorf("failed to start server: %w", err) + } + + // 构造响应数据 + response := &models.ServerOperationResponse{ + Operation: "start", + } + + // 尝试将结果解析到响应结构体中 + if err := result.Unmarshal(response); err != nil { + // 如果解析失败,则尝试获取任务 ID + taskId, _ := result.GetString("task_id") + response.TaskId = taskId + // 如果任务 ID 不为空,则认为操作成功 + response.Success = taskId != "" + } + + return response, nil +} + +// StopServer 停止 Cloudpods 中的服务器 +func (a *CloudpodsAdapter) StopServer(ctx context.Context, serverId string, req models.ServerStopRequest, ak string, sk string) (*models.ServerOperationResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造停止参数 + params := jsonutils.NewDict() + + // 如果需要强制停止,则设置相应参数 + if req.IsForce { + params.Set("is_force", jsonutils.NewBool(true)) + } + + // 如果需要停止计费,则设置相应参数 + if req.StopCharging { + params.Set("stop_charging", jsonutils.NewBool(true)) + } + + // 如果设置了超时时间,则设置相应参数 + if req.TimeoutSecs > 0 { + params.Set("timeout_secs", jsonutils.NewInt(req.TimeoutSecs)) + } + + // 调用 Cloudpods API 停止服务器 + result, err := compute.Servers.PerformAction(session, serverId, "stop", params) + if err != nil { + return nil, fmt.Errorf("failed to stop server: %w", err) + } + + // 构造响应数据 + response := &models.ServerOperationResponse{ + Operation: "stop", + } + + // 尝试将结果解析到响应结构体中 + if err := result.Unmarshal(response); err != nil { + // 如果解析失败,则尝试获取任务 ID + taskId, _ := result.GetString("task_id") + response.TaskId = taskId + // 如果任务 ID 不为空,则认为操作成功 + response.Success = taskId != "" + } + + return response, nil +} + +// RestartServer 重启 Cloudpods 中的服务器 +func (a *CloudpodsAdapter) RestartServer(ctx context.Context, serverId string, req models.ServerRestartRequest, ak string, sk string) (*models.ServerOperationResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造重启参数 + params := jsonutils.NewDict() + + // 如果需要强制重启,则设置相应参数 + if req.IsForce { + params.Set("is_force", jsonutils.NewBool(true)) + } + + // 调用 Cloudpods API 重启服务器 + result, err := compute.Servers.PerformAction(session, serverId, "restart", params) + if err != nil { + return nil, fmt.Errorf("failed to restart server: %w", err) + } + + // 构造响应数据 + response := &models.ServerOperationResponse{ + Operation: "restart", + } + + // 尝试将结果解析到响应结构体中 + if err := result.Unmarshal(response); err != nil { + // 如果解析失败,则尝试获取任务 ID + taskId, _ := result.GetString("task_id") + response.TaskId = taskId + // 如果任务 ID 不为空,则认为操作成功 + response.Success = taskId != "" + } + + return response, nil +} + +// ResetServerPassword 重置 Cloudpods 中服务器的密码 +func (a *CloudpodsAdapter) ResetServerPassword(ctx context.Context, serverId string, req models.ServerResetPasswordRequest, ak string, sk string) (*models.ServerOperationResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造密码重置参数 + params := jsonutils.NewDict() + // 设置新密码 + params.Set("password", jsonutils.NewString(req.Password)) + + if req.ResetPassword { + params.Set("reset_password", jsonutils.NewBool(true)) + } + + if req.AutoStart { + params.Set("auto_start", jsonutils.NewBool(true)) + } + + if req.Username != "" { + params.Set("username", jsonutils.NewString(req.Username)) + } + + // 调用 Cloudpods API 重置服务器密码 + result, err := compute.Servers.PerformAction(session, serverId, "reset-password", params) + if err != nil { + return nil, fmt.Errorf("failed to reset server password: %w", err) + } + + // 构造响应数据 + response := &models.ServerOperationResponse{ + Operation: "reset-password", + } + + // 尝试将结果解析到响应结构体中 + if err := result.Unmarshal(response); err != nil { + // 如果解析失败,则尝试获取任务 ID + taskId, _ := result.GetString("task_id") + response.TaskId = taskId + // 如果任务 ID 不为空,则认为操作成功 + response.Success = taskId != "" + } + + return response, nil +} + +// DeleteServer 删除 Cloudpods 中的服务器 +func (a *CloudpodsAdapter) DeleteServer(ctx context.Context, serverId string, req models.ServerDeleteRequest, ak string, sk string) (*models.ServerOperationResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造删除参数 + params := jsonutils.NewDict() + // 如果需要覆盖待删除状态,则设置相应参数 + if req.OverridePendingDelete { + params.Set("override_pending_delete", jsonutils.NewBool(true)) + } + // 如果需要彻底删除,则设置相应参数 + if req.Purge { + params.Set("purge", jsonutils.NewBool(true)) + } + // 如果需要删除快照,则设置相应参数 + if req.DeleteSnapshots { + params.Set("delete_snapshots", jsonutils.NewBool(true)) + } + // 如果需要删除弹性 IP,则设置相应参数 + if req.DeleteEip { + params.Set("delete_eip", jsonutils.NewBool(true)) + } + // 如果需要删除磁盘,则设置相应参数 + if req.DeleteDisks { + params.Set("delete_disks", jsonutils.NewBool(true)) + } + + // 调用 Cloudpods API 删除服务器 + result, err := compute.Servers.Delete(session, serverId, params) + if err != nil { + return nil, fmt.Errorf("failed to delete server: %w", err) + } + + // 构造响应数据 + response := &models.ServerOperationResponse{ + Operation: "delete", + } + + // 尝试将结果解析到响应结构体中 + if err := result.Unmarshal(response); err != nil { + // 如果解析失败,则尝试获取任务 ID + taskId, _ := result.GetString("task_id") + response.TaskId = taskId + // 如果任务 ID 不为空,则认为操作成功 + response.Success = taskId != "" + } + + return response, nil +} + +// CreateServer 在 Cloudpods 中创建服务器 +func (a *CloudpodsAdapter) CreateServer(ctx context.Context, req models.CreateServerRequest, ak string, sk string) (*models.CreateServerResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造创建服务器的参数 + params := jsonutils.NewDict() + // 设置服务器名称 + params.Set("name", jsonutils.NewString(req.Name)) + // 设置 CPU 核心数 + params.Set("vcpu_count", jsonutils.NewInt(req.VcpuCount)) + // 设置内存大小 + params.Set("vmem_size", jsonutils.NewInt(req.VmemSize)) + + // 如果创建数量大于1,则设置相应参数 + if req.Count > 1 { + params.Set("count", jsonutils.NewInt(int64(req.Count))) + } + + // 如果需要自动启动,则设置相应参数 + if req.AutoStart { + params.Set("auto_start", jsonutils.NewBool(req.AutoStart)) + } + + // 如果设置了密码,则设置相应参数 + if req.Password != "" { + params.Set("password", jsonutils.NewString(req.Password)) + } + + // 如果设置了计费类型,则设置相应参数 + if req.BillingType != "" { + params.Set("billing_type", jsonutils.NewString(req.BillingType)) + } + + // 如果设置了计费时长,则设置相应参数 + if req.Duration != "" { + params.Set("duration", jsonutils.NewString(req.Duration)) + } + + // 如果设置了描述,则设置相应参数 + if req.Description != "" { + params.Set("description", jsonutils.NewString(req.Description)) + } + + // 如果设置了主机名,则设置相应参数 + if req.Hostname != "" { + params.Set("hostname", jsonutils.NewString(req.Hostname)) + } + + // 如果设置了虚拟化类型,则设置相应参数 + if req.Hypervisor != "" { + params.Set("hypervisor", jsonutils.NewString(req.Hypervisor)) + } + + // 如果设置了用户数据,则设置相应参数 + if req.UserData != "" { + params.Set("user_data", jsonutils.NewString(req.UserData)) + } + + // 如果设置了密钥对 ID,则设置相应参数 + if req.KeypairId != "" { + params.Set("keypair_id", jsonutils.NewString(req.KeypairId)) + } + + // 如果设置了项目 ID,则设置相应参数 + if req.ProjectId != "" { + params.Set("project_id", jsonutils.NewString(req.ProjectId)) + } + + // 如果设置了可用区 ID,则设置相应参数 + if req.ZoneId != "" { + params.Set("prefer_zone_id", jsonutils.NewString(req.ZoneId)) + } + + // 如果设置了区域 ID,则设置相应参数 + if req.RegionId != "" { + params.Set("prefer_region_id", jsonutils.NewString(req.RegionId)) + } + + // 如果需要禁用删除,则设置相应参数 + if req.DisableDelete { + params.Set("disable_delete", jsonutils.NewBool(req.DisableDelete)) + } + + // 如果设置了启动顺序,则设置相应参数 + if req.BootOrder != "" { + params.Set("boot_order", jsonutils.NewString(req.BootOrder)) + } + + // 如果设置了元数据,则设置相应参数 + if len(req.Metadata) > 0 { + metaDict := jsonutils.NewDict() + for k, v := range req.Metadata { + metaDict.Set(k, jsonutils.NewString(v)) + } + params.Set("__meta__", metaDict) + } + + // 构造磁盘参数 + disks := jsonutils.NewArray() + + // 如果设置了镜像 ID,则构造系统磁盘参数 + if req.ImageId != "" { + diskDict := jsonutils.NewDict() + diskDict.Set("image_id", jsonutils.NewString(req.ImageId)) + diskDict.Set("disk_type", jsonutils.NewString("sys")) + if req.DiskSize > 0 { + diskDict.Set("size", jsonutils.NewInt(req.DiskSize)) + } + disks.Add(diskDict) + } + + // 构造数据磁盘参数 + for _, disk := range req.DataDisks { + diskDict := jsonutils.NewDict() + if disk.ImageId != "" { + diskDict.Set("image_id", jsonutils.NewString(disk.ImageId)) + } + if disk.Size > 0 { + diskDict.Set("size", jsonutils.NewInt(disk.Size)) + } + diskDict.Set("disk_type", jsonutils.NewString(disk.DiskType)) + disks.Add(diskDict) + } + + // 如果有磁盘参数,则设置相应参数 + if disks.Length() > 0 { + params.Set("disks", disks) + } + + // 如果设置了网络 ID,则构造网络参数 + if req.NetworkId != "" { + networks := jsonutils.NewArray() + netDict := jsonutils.NewDict() + netDict.Set("network", jsonutils.NewString(req.NetworkId)) + networks.Add(netDict) + params.Set("nets", networks) + } + + // 如果设置了安全组 ID,则设置相应参数 + if req.SecgroupId != "" { + params.Set("secgrp_id", jsonutils.NewString(req.SecgroupId)) + } + + // 如果设置了安全组列表,则设置相应参数 + if len(req.Secgroups) > 0 { + secgroups := jsonutils.NewArray() + for _, sg := range req.Secgroups { + secgroups.Add(jsonutils.NewString(sg)) + } + params.Set("secgroups", secgroups) + } + + // 如果设置了服务器规格 ID,则设置相应参数 + if req.ServerskuId != "" { + params.Set("instance_type", jsonutils.NewString(req.ServerskuId)) + } + + // 调用 Cloudpods API 创建服务器 + result, err := compute.Servers.Create(session, params) + if err != nil { + return nil, fmt.Errorf("failed to create server: %w", err) + } + + // 构造响应数据 + response := &models.CreateServerResponse{} + if err := result.Unmarshal(response); err != nil { + return nil, fmt.Errorf("failed to unmarshal create server response: %w", err) + } + + return response, nil +} + +// GetServerMonitor 获取 Cloudpods 中服务器的监控数据 +func (a *CloudpodsAdapter) GetServerMonitor(ctx context.Context, serverId string, startTime, endTime int64, metrics []string, ak string, sk string) (*models.MonitorResponse, error) { + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + params := jsonutils.NewDict() + + metricQuery := jsonutils.NewArray() + + for _, metric := range metrics { + + modelDict := jsonutils.NewDict() + + modelDict.Set("database", jsonutils.NewString("telegraf")) + modelDict.Set("measurement", jsonutils.NewString("vm_cpu")) + + switch metric { + case "cpu_usage": + modelDict.Set("measurement", jsonutils.NewString("vm_cpu")) + case "mem_usage": + modelDict.Set("measurement", jsonutils.NewString("vm_mem")) + case "disk_usage": + modelDict.Set("measurement", jsonutils.NewString("vm_disk")) + case "net_bps_rx", "net_bps_tx": + modelDict.Set("measurement", jsonutils.NewString("vm_netio")) + } + + tagsArray := jsonutils.NewArray() + tagDict := jsonutils.NewDict() + tagDict.Set("key", jsonutils.NewString("vm_id")) + tagDict.Set("operator", jsonutils.NewString("=")) + tagDict.Set("value", jsonutils.NewString(serverId)) + tagsArray.Add(tagDict) + modelDict.Set("tags", tagsArray) + + queryDict := jsonutils.NewDict() + queryDict.Set("model", modelDict) + + if startTime > 0 { + queryDict.Set("from", jsonutils.NewString(fmt.Sprintf("%d", startTime))) + } + if endTime > 0 { + queryDict.Set("to", jsonutils.NewString(fmt.Sprintf("%d", endTime))) + } + + metricQuery.Add(queryDict) + } + + params.Set("metric_query", metricQuery) + params.Set("scope", jsonutils.NewString("system")) + + params.Set("interval", jsonutils.NewString("60s")) + + result, err := monitor.UnifiedMonitorManager.PerformAction(session, "query", "", params) + if err != nil { + return nil, fmt.Errorf("failed to get server monitor data: %w", err) + } + + response := &models.MonitorResponse{ + Status: 200, + Data: models.MonitorResponseData{ + Metrics: []models.MetricData{}, + }, + } + + unifiedmonitor, err := result.Get("unifiedmonitor") + if err != nil { + return nil, fmt.Errorf("failed to get unifiedmonitor data: %w", err) + } + + series, err := unifiedmonitor.Get("Series") + if err != nil { + return nil, fmt.Errorf("failed to get series data: %w", err) + } + + seriesArray, ok := series.(*jsonutils.JSONArray) + if !ok { + return nil, fmt.Errorf("invalid series data format") + } + + for i := 0; i < seriesArray.Length(); i++ { + seriesObj, err := seriesArray.GetAt(i) + if err != nil { + continue + } + + name, _ := seriesObj.GetString("name") + + metricData := models.MetricData{ + Metric: name, + Unit: "%", + Values: []models.MetricValue{}, + } + + if strings.Contains(name, "net_bps") { + metricData.Unit = "bps" + } else if strings.Contains(name, "disk_io") { + metricData.Unit = "iops" + } + + points, err := seriesObj.Get("points") + if err != nil { + continue + } + + pointsArray, ok := points.(*jsonutils.JSONArray) + if !ok { + continue + } + + for j := 0; j < pointsArray.Length(); j++ { + pointObj, err := pointsArray.GetAt(j) + if err != nil { + continue + } + + pointArray, ok := pointObj.(*jsonutils.JSONArray) + if !ok || pointArray.Length() < 2 { + continue + } + + timestamp, err := pointArray.GetAt(0) + if err != nil { + continue + } + + value, err := pointArray.GetAt(1) + if err != nil { + continue + } + + timestampStr, _ := timestamp.GetString() + valueStr, _ := value.GetString() + + timestampInt, _ := strconv.ParseInt(timestampStr, 10, 64) + valueFloat, _ := strconv.ParseFloat(valueStr, 64) + + metricData.Values = append(metricData.Values, models.MetricValue{ + Timestamp: timestampInt, + Value: valueFloat, + }) + } + + response.Data.Metrics = append(response.Data.Metrics, metricData) + } + + return response, nil +} + +// GetServerStats 获取 Cloudpods 中服务器的实时统计数据 +func (a *CloudpodsAdapter) GetServerStats(ctx context.Context, serverId string, ak string, sk string) (*models.ServerStatsResponse, error) { + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + params := jsonutils.NewDict() + result, err := compute.Servers.GetSpecific(session, serverId, "stats", params) + if err != nil { + return nil, fmt.Errorf("failed to get server stats: %w", err) + } + + statsData := models.ServerStatsData{} + + cpuUsed, _ := result.Float("cpu_used") + statsData.CPUUsage = cpuUsed * 100 + + memSize, _ := result.Int("mem_size") + memUsed, _ := result.Int("mem_used") + if memSize > 0 { + statsData.MemUsage = float64(memUsed) / float64(memSize) * 100 + } + + diskSize, _ := result.Int("disk_size") + diskUsed, _ := result.Int("disk_used") + if diskSize > 0 { + statsData.DiskUsage = float64(diskUsed) / float64(diskSize) * 100 + } + + netInRate, _ := result.Float("net_in_rate") + netOutRate, _ := result.Float("net_out_rate") + statsData.NetBpsRx = int64(netInRate) + statsData.NetBpsTx = int64(netOutRate) + + statsData.UpdatedAt = time.Now().Format("2006-01-02 15:04:05") + + response := &models.ServerStatsResponse{ + Status: 200, + Data: statsData, + } + + return response, nil +} diff --git a/pkg/mcp-server/adapters/resource_query_adapter.go b/pkg/mcp-server/adapters/resource_query_adapter.go new file mode 100644 index 0000000000..3a13424147 --- /dev/null +++ b/pkg/mcp-server/adapters/resource_query_adapter.go @@ -0,0 +1,493 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package adapters + +import ( + "context" + "fmt" + + "yunion.io/x/jsonutils" + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcclient/modules/compute" + "yunion.io/x/onecloud/pkg/mcclient/modules/image" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// ListCloudRegions 查询 Cloudpods 中的区域列表 +func (a CloudpodsAdapter) ListCloudRegions(ctx context.Context, limit int, offset int, search string, provider string, ak string, sk string) (*models.CloudregionListResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造查询参数 + params := jsonutils.NewDict() + if limit > 0 { + // 设置查询结果数量限制 + params.Set("limit", jsonutils.NewInt(int64(limit))) + } + if offset > 0 { + // 设置查询偏移量 + params.Set("offset", jsonutils.NewInt(int64(offset))) + } + if search != "" { + // 设置搜索关键字 + params.Set("search", jsonutils.NewString(search)) + } + if provider != "" { + // 设置云提供商过滤条件 + providers := jsonutils.NewArray() + providers.Add(jsonutils.NewString(provider)) + params.Set("providers", providers) + } + + // 调用 Cloudpods API 查询区域列表 + result, err := compute.Cloudregions.List(session, params) + if err != nil { + return nil, err + } + + // 构造响应数据 + response := &models.CloudregionListResponse{ + Limit: int64(limit), + Offset: int64(offset), + Cloudregions: make([]models.CloudregionDetails, 0), + Total: int64(result.Total), + } + // 遍历查询结果,将数据转换为响应格式 + for _, data := range result.Data { + region := models.CloudregionDetails{} + if err := data.Unmarshal(®ion); err != nil { + // 如果数据转换失败,记录警告日志并跳过该条数据 + log.Warningf("Failed to unmarshal cloudregion details: %s", err) + continue + } + response.Cloudregions = append(response.Cloudregions, region) + } + + return response, nil +} + +// ListVPCs 查询 Cloudpods 中的 VPC 列表 +func (a *CloudpodsAdapter) ListVPCs(limit int, offset int, search string, cloudregionId string, ak string, sk string) (*models.VpcListResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造查询参数 + params := jsonutils.NewDict() + if limit > 0 { + // 设置查询结果数量限制 + params.Set("limit", jsonutils.NewInt(int64(limit))) + } + if offset > 0 { + // 设置查询偏移量 + params.Set("offset", jsonutils.NewInt(int64(offset))) + } + if search != "" { + // 设置搜索关键字 + params.Set("search", jsonutils.NewString(search)) + } + if cloudregionId != "" { + // 设置云区域 ID 过滤条件 + cloudregionIds := jsonutils.NewArray() + cloudregionIds.Add(jsonutils.NewString(cloudregionId)) + params.Set("cloudregion_id", cloudregionIds) + } + + // 调用 Cloudpods API 查询 VPC 列表 + result, err := compute.Vpcs.List(session, params) + if err != nil { + return nil, fmt.Errorf("failed to list vpcs: %w", err) + } + + // 构造响应数据 + response := &models.VpcListResponse{ + Limit: int64(limit), + Offset: int64(offset), + Vpcs: make([]models.VpcDetails, 0), + Total: int64(result.Total), + } + // 遍历查询结果,将数据转换为响应格式 + for _, data := range result.Data { + vpc := models.VpcDetails{} + if err := data.Unmarshal(&vpc); err != nil { + // 如果数据转换失败,记录警告日志并跳过该条数据 + log.Warningf("Failed to unmarshal vpc details: %s", err) + continue + } + response.Vpcs = append(response.Vpcs, vpc) + } + + return response, nil +} + +// ListNetworks 查询 Cloudpods 中的网络列表 +func (a *CloudpodsAdapter) ListNetworks(limit int, offset int, search string, vpcId string, ak string, sk string) (*models.NetworkListResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造查询参数 + params := jsonutils.NewDict() + if limit > 0 { + // 设置查询结果数量限制 + params.Set("limit", jsonutils.NewInt(int64(limit))) + } + if offset > 0 { + // 设置查询偏移量 + params.Set("offset", jsonutils.NewInt(int64(offset))) + } + if search != "" { + // 设置搜索关键字 + params.Set("search", jsonutils.NewString(search)) + } + if vpcId != "" { + // 设置 VPC ID 过滤条件 + //vpcIds := jsonutils.NewArray() + //vpcIds.Add(jsonutils.NewString(vpcId)) + //params.Set("vpc_id", vpcIds) + params.Set("vpc_id", jsonutils.NewString(vpcId)) + } + + // 调用 Cloudpods API 查询网络列表 + result, err := compute.Networks.List(session, params) + if err != nil { + return nil, fmt.Errorf("failed to list networks: %w", err) + } + + // 构造响应数据 + response := &models.NetworkListResponse{ + Limit: int64(limit), + Offset: int64(offset), + Networks: make([]models.NetworkDetails, 0), + Total: int64(result.Total), + } + // 遍历查询结果,将数据转换为响应格式 + for _, data := range result.Data { + network := models.NetworkDetails{} + if err := data.Unmarshal(&network); err != nil { + // 如果数据转换失败,记录警告日志并跳过该条数据 + log.Warningf("Failed to unmarshal network details: %s", err) + continue + } + response.Networks = append(response.Networks, network) + } + + return response, nil +} + +// ListImages 查询 Cloudpods 中的镜像列表 +func (a *CloudpodsAdapter) ListImages(limit int, offset int, search string, osTypes []string, ak string, sk string) (*models.ImageListResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造查询参数 + params := jsonutils.NewDict() + if limit > 0 { + // 设置查询结果数量限制 + params.Set("limit", jsonutils.NewInt(int64(limit))) + } + if offset > 0 { + // 设置查询偏移量 + params.Set("offset", jsonutils.NewInt(int64(offset))) + } + if search != "" { + // 设置搜索关键字 + params.Set("search", jsonutils.NewString(search)) + } + if len(osTypes) > 0 { + // 设置操作系统类型过滤条件 + osTypesArray := jsonutils.NewArray() + for _, osType := range osTypes { + osTypesArray.Add(jsonutils.NewString(osType)) + } + params.Set("os_types", osTypesArray) + } + + // 调用 Cloudpods API 查询镜像列表 + result, err := image.Images.List(session, params) + if err != nil { + return nil, fmt.Errorf("failed to list images: %w", err) + } + + // 构造响应数据 + response := &models.ImageListResponse{ + Limit: int64(limit), + Offset: int64(offset), + Images: make([]models.ImageDetails, 0), + Total: int64(result.Total), + } + // 遍历查询结果,将数据转换为响应格式 + for _, data := range result.Data { + image := models.ImageDetails{} + if err := data.Unmarshal(&image); err != nil { + // 如果数据转换失败,记录警告日志并跳过该条数据 + log.Warningf("Failed to unmarshal image details: %s", err) + continue + } + response.Images = append(response.Images, image) + } + + return response, nil +} + +// ListServerSkus 查询 Cloudpods 中的服务器规格列表 +func (a *CloudpodsAdapter) ListServerSkus(limit int, offset int, search string, cloudregionIds []string, zoneIds []string, cpuCoreCount []string, memorySizeMB []string, providers []string, cpuArch []string, ak string, sk string) (*models.ServerSkuListResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造查询参数 + params := jsonutils.NewDict() + if limit > 0 { + // 设置查询结果数量限制 + params.Set("limit", jsonutils.NewInt(int64(limit))) + } + if offset > 0 { + // 设置查询偏移量 + params.Set("offset", jsonutils.NewInt(int64(offset))) + } + if search != "" { + // 设置搜索关键字 + params.Set("search", jsonutils.NewString(search)) + } + if len(cloudregionIds) > 0 { + // 设置云区域 ID 过滤条件 + cloudregionIdArray := jsonutils.NewArray() + for _, id := range cloudregionIds { + cloudregionIdArray.Add(jsonutils.NewString(id)) + } + params.Set("cloudregion_id", cloudregionIdArray) + } + if len(zoneIds) > 0 { + // 设置可用区 ID 过滤条件 + zoneIdArray := jsonutils.NewArray() + for _, id := range zoneIds { + zoneIdArray.Add(jsonutils.NewString(id)) + } + params.Set("zone_ids", zoneIdArray) + } + if len(cpuCoreCount) > 0 { + // 设置 CPU 核心数过滤条件 + cpuCoreArray := jsonutils.NewArray() + for _, count := range cpuCoreCount { + cpuCoreArray.Add(jsonutils.NewString(count)) + } + params.Set("cpu_core_count", cpuCoreArray) + } + if len(memorySizeMB) > 0 { + // 设置内存大小过滤条件 + memoryArray := jsonutils.NewArray() + for _, size := range memorySizeMB { + memoryArray.Add(jsonutils.NewString(size)) + } + params.Set("memory_size_mb", memoryArray) + } + if len(providers) > 0 { + // 设置提供商过滤条件 + providerArray := jsonutils.NewArray() + for _, provider := range providers { + providerArray.Add(jsonutils.NewString(provider)) + } + params.Set("providers", providerArray) + } + if len(cpuArch) > 0 { + // 设置 CPU 架构过滤条件 + cpuArchArray := jsonutils.NewArray() + for _, arch := range cpuArch { + cpuArchArray.Add(jsonutils.NewString(arch)) + } + params.Set("cpu_arch", cpuArchArray) + } + + // 调用 Cloudpods API 查询服务器规格列表 + result, err := compute.ServerSkus.List(session, params) + if err != nil { + return nil, fmt.Errorf("failed to list server skus: %w", err) + } + + // 构造响应数据 + response := &models.ServerSkuListResponse{ + Limit: int64(limit), + Offset: int64(offset), + Serverskus: make([]models.ServerSkuDetails, 0), + Total: int64(result.Total), + } + // 遍历查询结果,将数据转换为响应格式 + for _, data := range result.Data { + sku := models.ServerSkuDetails{} + if err := data.Unmarshal(&sku); err != nil { + // 如果数据转换失败,记录警告日志并跳过该条数据 + log.Warningf("Failed to unmarshal server sku details: %s", err) + continue + } + response.Serverskus = append(response.Serverskus, sku) + } + + return response, nil +} + +// ListStorages 查询 Cloudpods 中的存储列表 +func (a *CloudpodsAdapter) ListStorages(limit int, offset int, search string, cloudregionIds []string, zoneIds []string, providers []string, storageTypes []string, hostId string, ak string, sk string) (*models.StorageListResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造查询参数 + params := jsonutils.NewDict() + if limit > 0 { + // 设置查询结果数量限制 + params.Set("limit", jsonutils.NewInt(int64(limit))) + } + if offset > 0 { + // 设置查询偏移量 + params.Set("offset", jsonutils.NewInt(int64(offset))) + } + if search != "" { + // 设置搜索关键字 + params.Set("search", jsonutils.NewString(search)) + } + if len(cloudregionIds) > 0 { + // 设置云区域 ID 过滤条件 + cloudregionIdArray := jsonutils.NewArray() + for _, id := range cloudregionIds { + cloudregionIdArray.Add(jsonutils.NewString(id)) + } + params.Set("cloudregion_id", cloudregionIdArray) + } + if len(zoneIds) > 0 { + // 设置可用区 ID 过滤条件 + zoneIdArray := jsonutils.NewArray() + for _, id := range zoneIds { + zoneIdArray.Add(jsonutils.NewString(id)) + } + params.Set("zone_ids", zoneIdArray) + } + if len(providers) > 0 { + // 设置提供商过滤条件 + providerArray := jsonutils.NewArray() + for _, provider := range providers { + providerArray.Add(jsonutils.NewString(provider)) + } + params.Set("providers", providerArray) + } + if len(storageTypes) > 0 { + // 设置存储类型过滤条件 + for _, storageType := range storageTypes { + params.Set("storage_type", jsonutils.NewString(storageType)) + break + } + } + if hostId != "" { + // 设置主机 ID 过滤条件 + params.Set("host_id", jsonutils.NewString(hostId)) + } + + // 调用 Cloudpods API 查询存储列表 + result, err := compute.Storages.List(session, params) + if err != nil { + return nil, fmt.Errorf("failed to list storages: %w", err) + } + + // 构造响应数据 + response := &models.StorageListResponse{ + Limit: int64(limit), + Offset: int64(offset), + Storages: make([]models.StorageDetails, 0), + Total: int64(result.Total), + } + + // 遍历查询结果,将数据转换为响应格式 + for _, data := range result.Data { + storage := models.StorageDetails{} + if err := data.Unmarshal(&storage); err != nil { + // 如果数据转换失败,记录警告日志并跳过该条数据 + log.Warningf("Failed to unmarshal storage details: %s", err) + continue + } + response.Storages = append(response.Storages, storage) + } + + return response, nil +} + +// ListServers 查询 Cloudpods 中的服务器列表 +func (a *CloudpodsAdapter) ListServers(ctx context.Context, limit int, offset int, search string, status string, ak string, sk string) (*models.ServerListResponse, error) { + // 获取 Cloudpods 会话 + session, err := a.getSession(ak, sk) + if err != nil { + return nil, err + } + + // 构造查询参数 + params := jsonutils.NewDict() + if limit > 0 { + // 设置查询结果数量限制 + params.Set("limit", jsonutils.NewInt(int64(limit))) + } + if offset > 0 { + // 设置查询偏移量 + params.Set("offset", jsonutils.NewInt(int64(offset))) + } + if search != "" { + // 设置搜索关键字 + params.Set("search", jsonutils.NewString(search)) + } + if status != "" { + // 设置服务器状态过滤条件 + params.Set("status", jsonutils.NewString(status)) + } + + // 调用 Cloudpods API 查询服务器列表 + result, err := compute.Servers.List(session, params) + if err != nil { + return nil, fmt.Errorf("failed to list servers: %w", err) + } + + // 构造响应数据 + response := &models.ServerListResponse{ + Limit: int64(limit), + Offset: int64(offset), + Servers: make([]models.ServerDetails, 0), + Total: int64(result.Total), + } + + // 遍历查询结果,将数据转换为响应格式 + for _, data := range result.Data { + server := models.ServerDetails{} + if err := data.Unmarshal(&server); err != nil { + // 如果数据转换失败,记录警告日志并跳过该条数据 + log.Warningf("Failed to unmarshal server details: %s", err) + continue + } + response.Servers = append(response.Servers, server) + } + + return response, nil +} diff --git a/pkg/mcp-server/models/doc.go b/pkg/mcp-server/models/doc.go new file mode 100644 index 0000000000..f6fb292fe7 --- /dev/null +++ b/pkg/mcp-server/models/doc.go @@ -0,0 +1 @@ +package models // import "yunion.io/x/onecloud/pkg/mcp-server/models" diff --git a/pkg/mcp-server/models/models.go b/pkg/mcp-server/models/models.go new file mode 100644 index 0000000000..98a2604eee --- /dev/null +++ b/pkg/mcp-server/models/models.go @@ -0,0 +1,774 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package models + +import "time" + +type ListRegionsReq struct { +} + +type CloudregionDetails struct { + CanDelete bool `json:"can_delete"` + CanUpdate bool `json:"can_update"` + City string `json:"city"` + CloudEnv string `json:"cloud_env"` + CountryCode string `json:"country_code"` + CreatedAt *time.Time `json:"created_at"` + Deleted bool `json:"deleted"` + DeletedAt *time.Time `json:"deleted_at"` + Description string `json:"description"` + Enabled bool `json:"enabled"` + Environment string `json:"environment"` + ExternalId string `json:"external_id"` + GuestCount int64 `json:"guest_count"` + GuestIncrementCount int64 `json:"guest_increment_count"` + Id string `json:"id"` + ImportedAt *time.Time `json:"imported_at"` + IsEmulated bool `json:"is_emulated"` + Latitude float64 `json:"latitude"` + Longitude float64 `json:"longitude"` + Metadata map[string]string `json:"metadata"` + Name string `json:"name"` + NetworkCount int64 `json:"network_count"` + Progress float64 `json:"progress"` + Provider string `json:"provider"` + Source string `json:"source"` + Status string `json:"status"` + UpdateVersion int64 `json:"update_version"` + UpdatedAt *time.Time `json:"updated_at"` + VpcCount int64 `json:"vpc_count"` + ZoneCount int64 `json:"zone_count"` +} + +type CloudregionListResponse struct { + Limit int64 `json:"limit"` + Offset int64 `json:"offset"` + Cloudregions []CloudregionDetails `json:"cloudregions"` + Total int64 `json:"total"` +} + +type SharedDomain struct { + Id string `json:"id"` + Name string `json:"name"` +} + +type SharedProject struct { + Domain string `json:"domain"` + DomainId string `json:"domain_id"` + Id string `json:"id"` + Name string `json:"name"` +} + +type VpcDetails struct { + Account string `json:"account"` + AccountHealthStatus string `json:"account_health_status"` + AccountId string `json:"account_id"` + AccountReadOnly bool `json:"account_read_only"` + AccountStatus string `json:"account_status"` + AcceptVpcPeerCount int64 `json:"accpet_vpc_peer_count"` + Brand string `json:"brand"` + CanDelete bool `json:"can_delete"` + CanUpdate bool `json:"can_update"` + CidrBlock string `json:"cidr_block"` + CidrBlock6 string `json:"cidr_block6"` + CloudEnv string `json:"cloud_env"` + Cloudregion string `json:"cloudregion"` + CloudregionId string `json:"cloudregion_id"` + CreatedAt *time.Time `json:"created_at"` + Deleted bool `json:"deleted"` + DeletedAt *time.Time `json:"deleted_at"` + Description string `json:"description"` + Direct bool `json:"direct"` + DnsZoneCount int64 `json:"dns_zone_count"` + DomainId string `json:"domain_id"` + DomainSrc string `json:"domain_src"` + Enabled bool `json:"enabled"` + Environment string `json:"environment"` + ExternalAccessMode string `json:"external_access_mode"` + ExternalId string `json:"external_id"` + Globalvpc string `json:"globalvpc"` + GlobalvpcId string `json:"globalvpc_id"` + Id string `json:"id"` + ImportedAt *time.Time `json:"imported_at"` + IsDefault bool `json:"is_default"` + IsEmulated bool `json:"is_emulated"` + IsPublic bool `json:"is_public"` + Manager string `json:"manager"` + ManagerDomain string `json:"manager_domain"` + ManagerDomainId string `json:"manager_domain_id"` + ManagerId string `json:"manager_id"` + ManagerProject string `json:"manager_project"` + ManagerProjectId string `json:"manager_project_id"` + Metadata map[string]string `json:"metadata"` + Name string `json:"name"` + NatgatewayCount int64 `json:"natgateway_count"` + NetworkCount int64 `json:"network_count"` + Progress float64 `json:"progress"` + ProjectDomain string `json:"project_domain"` + Provider string `json:"provider"` + PublicScope string `json:"public_scope"` + PublicSrc string `json:"public_src"` + Region string `json:"region"` + RegionExtId string `json:"region_ext_id"` + RegionExternalId string `json:"region_external_id"` + RegionId string `json:"region_id"` + RequestVpcPeerCount int64 `json:"request_vpc_peer_count"` + RoutetableCount int64 `json:"routetable_count"` + SharedDomains []SharedDomain `json:"shared_domains"` + SharedProjects []SharedProject `json:"shared_projects"` + Source string `json:"source"` + Status string `json:"status"` + UpdateVersion int64 `json:"update_version"` + UpdatedAt *time.Time `json:"updated_at"` + WireCount int64 `json:"wire_count"` +} + +type VpcListResponse struct { + Limit int64 `json:"limit"` + Offset int64 `json:"offset"` + Vpcs []VpcDetails `json:"vpcs"` + Total int64 `json:"total"` +} + +type SchedtagShortDescDetails struct { + Default string `json:"default"` + Id string `json:"id"` + Name string `json:"name"` + ResName string `json:"res_name"` +} + +type SRoute []string + +type SSimpleWire struct { + Wire string `json:"Wire"` + WireId string `json:"WireId"` +} + +type NetworkDetails struct { + Account string `json:"account"` + AccountHealthStatus string `json:"account_health_status"` + AccountId string `json:"account_id"` + AccountReadOnly bool `json:"account_read_only"` + AccountStatus string `json:"account_status"` + AdditionalWires []SSimpleWire `json:"additional_wires"` + AllocPolicy string `json:"alloc_policy"` + AllocTimoutSeconds int64 `json:"alloc_timout_seconds"` + BgpType string `json:"bgp_type"` + BmReusedVnics int64 `json:"bm_reused_vnics"` + BmVnics int64 `json:"bm_vnics"` + Brand string `json:"brand"` + CanDelete bool `json:"can_delete"` + CanUpdate bool `json:"can_update"` + CloudEnv string `json:"cloud_env"` + Cloudregion string `json:"cloudregion"` + CloudregionId string `json:"cloudregion_id"` + CreatedAt *time.Time `json:"created_at"` + Deleted bool `json:"deleted"` + DeletedAt *time.Time `json:"deleted_at"` + Description string `json:"description"` + Dns string `json:"dns"` + DomainId string `json:"domain_id"` + EipVnics int64 `json:"eip_vnics"` + Environment string `json:"environment"` + Exit bool `json:"exit"` + ExternalId string `json:"external_id"` + Freezed bool `json:"freezed"` + GroupVnics int64 `json:"group_vnics"` + GuestDhcp string `json:"guest_dhcp"` + GuestDns string `json:"guest_dns"` + GuestDns6 string `json:"guest_dns6"` + GuestDomain string `json:"guest_domain"` + GuestDomain6 string `json:"guest_domain6"` + GuestGateway string `json:"guest_gateway"` + GuestGateway6 string `json:"guest_gateway6"` + GuestIpEnd string `json:"guest_ip_end"` + GuestIpMask uint8 `json:"guest_ip_mask"` + GuestIpStart string `json:"guest_ip_start"` + GuestIp6End string `json:"guest_ip6_end"` + GuestIp6Mask uint8 `json:"guest_ip6_mask"` + GuestIp6Start string `json:"guest_ip6_start"` + GuestNtp string `json:"guest_ntp"` + Id string `json:"id"` + IfnameHint string `json:"ifname_hint"` + ImportedAt *time.Time `json:"imported_at"` + IsAutoAlloc bool `json:"is_auto_alloc"` + IsClassic bool `json:"is_classic"` + IsDefaultVpc bool `json:"is_default_vpc"` + IsEmulated bool `json:"is_emulated"` + IsPublic bool `json:"is_public"` + IsSystem bool `json:"is_system"` + LbVnics int64 `json:"lb_vnics"` + Manager string `json:"manager"` + ManagerDomain string `json:"manager_domain"` + ManagerDomainId string `json:"manager_domain_id"` + ManagerId string `json:"manager_id"` + ManagerProject string `json:"manager_project"` + ManagerProjectId string `json:"manager_project_id"` + Metadata map[string]string `json:"metadata"` + Name string `json:"name"` + NatVnics int64 `json:"nat_vnics"` + NetworkinterfaceVnics int64 `json:"networkinterface_vnics"` + PendingDeleted bool `json:"pending_deleted"` + PendingDeletedAt *time.Time `json:"pending_deleted_at"` + Ports int64 `json:"ports"` + PortsUsed int64 `json:"ports_used"` + Ports6Used int64 `json:"ports6_used"` + Progress float64 `json:"progress"` + Project string `json:"project"` + ProjectDomain string `json:"project_domain"` + ProjectId string `json:"project_id"` + ProjectMetadata map[string]string `json:"project_metadata"` + ProjectSrc string `json:"project_src"` + Provider string `json:"provider"` + PublicScope string `json:"public_scope"` + PublicSrc string `json:"public_src"` + RdsVnics int64 `json:"rds_vnics"` + Region string `json:"region"` + RegionExtId string `json:"region_ext_id"` + RegionExternalId string `json:"region_external_id"` + RegionId string `json:"region_id"` + ReserveVnics4 int64 `json:"reserve_vnics4"` + ReserveVnics6 int64 `json:"reserve_vnics6"` + Routes []SRoute `json:"routes"` + Schedtags []SchedtagShortDescDetails `json:"schedtags"` + ServerType string `json:"server_type"` + SharedDomains []SharedDomain `json:"shared_domains"` + SharedProjects []SharedProject `json:"shared_projects"` + Source string `json:"source"` + Status string `json:"status"` + Tenant string `json:"tenant"` + TenantId string `json:"tenant_id"` + Total int64 `json:"total"` + Total6 int64 `json:"total6"` + UpdateVersion int64 `json:"update_version"` + UpdatedAt *time.Time `json:"updated_at"` + VlanId int64 `json:"vlan_id"` + Vnics int64 `json:"vnics"` + Vnics4 int64 `json:"vnics4"` + Vnics6 int64 `json:"vnics6"` + Vpc string `json:"vpc"` + VpcExtId string `json:"vpc_ext_id"` + VpcId string `json:"vpc_id"` + Wire string `json:"wire"` + WireId string `json:"wire_id"` + Zone string `json:"zone"` + ZoneId string `json:"zone_id"` +} + +type NetworkListResponse struct { + Limit int64 `json:"limit"` + Offset int64 `json:"offset"` + Networks []NetworkDetails `json:"networks"` + Total int64 `json:"total"` +} + +type ImageDetails struct { + AutoDeleteAt *time.Time `json:"auto_delete_at"` + CanDelete bool `json:"can_delete"` + CanUpdate bool `json:"can_update"` + Checksum string `json:"checksum"` + CreatedAt *time.Time `json:"created_at"` + Deleted bool `json:"deleted"` + DeletedAt *time.Time `json:"deleted_at"` + Description string `json:"description"` + DisableDelete bool `json:"disable_delete"` + DiskFormat string `json:"disk_format"` + DomainId string `json:"domain_id"` + EncryptAlg string `json:"encrypt_alg"` + EncryptKey string `json:"encrypt_key"` + EncryptKeyId string `json:"encrypt_key_id"` + EncryptKeyUser string `json:"encrypt_key_user"` + EncryptKeyUserDomain string `json:"encrypt_key_user_domain"` + EncryptKeyUserDomainId string `json:"encrypt_key_user_domain_id"` + EncryptKeyUserId string `json:"encrypt_key_user_id"` + EncryptStatus string `json:"encrypt_status"` + FastHash string `json:"fast_hash"` + Freezed bool `json:"freezed"` + Id string `json:"id"` + IsData bool `json:"is_data"` + IsEmulated bool `json:"is_emulated"` + IsGuestImage bool `json:"is_guest_image"` + IsPublic bool `json:"is_public"` + IsStandard bool `json:"is_standard"` + IsSystem bool `json:"is_system"` + Location string `json:"location"` + Metadata map[string]string `json:"metadata"` + MinDisk int32 `json:"min_disk"` + MinRam int32 `json:"min_ram"` + Name string `json:"name"` + OsArch string `json:"os_arch"` + OssChecksum string `json:"oss_checksum"` + Owner string `json:"owner"` + PendingDeleted bool `json:"pending_deleted"` + PendingDeletedAt *time.Time `json:"pending_deleted_at"` + Progress float64 `json:"progress"` + Project string `json:"project"` + ProjectDomain string `json:"project_domain"` + ProjectId string `json:"project_id"` + ProjectMetadata map[string]string `json:"project_metadata"` + ProjectSrc string `json:"project_src"` + Properties map[string]string `json:"properties"` + Protected bool `json:"protected"` + PublicScope string `json:"public_scope"` + PublicSrc string `json:"public_src"` + SharedDomains []SharedDomain `json:"shared_domains"` + SharedProjects []SharedProject `json:"shared_projects"` + Size int64 `json:"size"` + Status string `json:"status"` + Tenant string `json:"tenant"` + TenantId string `json:"tenant_id"` + UpdateVersion int64 `json:"update_version"` + UpdatedAt *time.Time `json:"updated_at"` +} + +type ImageListResponse struct { + Limit int64 `json:"limit"` + Offset int64 `json:"offset"` + Images []ImageDetails `json:"images"` + Total int64 `json:"total"` +} + +type ServerSkuDetails struct { + AttachedDiskCount int64 `json:"attached_disk_count"` + AttachedDiskSizeGB int64 `json:"attached_disk_size_gb"` + AttachedDiskType string `json:"attached_disk_type"` + CanDelete bool `json:"can_delete"` + CanUpdate bool `json:"can_update"` + CloudEnv string `json:"cloud_env"` + Cloudregion string `json:"cloudregion"` + CloudregionId string `json:"cloudregion_id"` + CpuArch string `json:"cpu_arch"` + CpuCoreCount int64 `json:"cpu_core_count"` + CreatedAt *time.Time `json:"created_at"` + DataDiskMaxCount int64 `json:"data_disk_max_count"` + DataDiskTypes string `json:"data_disk_types"` + Deleted bool `json:"deleted"` + DeletedAt *time.Time `json:"deleted_at"` + Description string `json:"description"` + Enabled bool `json:"enabled"` + ExternalId string `json:"external_id"` + GpuAttachable bool `json:"gpu_attachable"` + GpuCount string `json:"gpu_count"` + GpuMaxCount int64 `json:"gpu_max_count"` + GpuSpec string `json:"gpu_spec"` + Id string `json:"id"` + ImportedAt *time.Time `json:"imported_at"` + InstanceTypeCategory string `json:"instance_type_category"` + InstanceTypeFamily string `json:"instance_type_family"` + IsEmulated bool `json:"is_emulated"` + LocalCategory string `json:"local_category"` + Md5 string `json:"md5"` + MemorySizeMB int64 `json:"memory_size_mb"` + Metadata map[string]string `json:"metadata"` + Name string `json:"name"` + NicMaxCount int64 `json:"nic_max_count"` + NicType string `json:"nic_type"` + OsName string `json:"os_name"` + PostpaidStatus string `json:"postpaid_status"` + PrepaidStatus string `json:"prepaid_status"` + Progress float64 `json:"progress"` + Provider string `json:"provider"` + Region string `json:"region"` + RegionExtId string `json:"region_ext_id"` + RegionExternalId string `json:"region_external_id"` + RegionId string `json:"region_id"` + Source string `json:"source"` + Status string `json:"status"` + SysDiskMaxSizeGB int64 `json:"sys_disk_max_size_gb"` + SysDiskMinSizeGB int64 `json:"sys_disk_min_size_gb"` + SysDiskResizable bool `json:"sys_disk_resizable"` + SysDiskType string `json:"sys_disk_type"` + TotalGuestCount int64 `json:"total_guest_count"` + UpdateVersion int64 `json:"update_version"` + UpdatedAt *time.Time `json:"updated_at"` + Zone string `json:"zone"` + ZoneExtId string `json:"zone_ext_id"` + ZoneId string `json:"zone_id"` +} + +type ServerSkuListResponse struct { + Limit int64 `json:"limit"` + Offset int64 `json:"offset"` + Serverskus []ServerSkuDetails `json:"serverskus"` + Total int64 `json:"total"` +} + +type StorageHost struct { + HostStatus string `json:"HostStatus"` + Id string `json:"Id"` + Name string `json:"Name"` + Status string `json:"Status"` +} + +type StorageDetails struct { + DiskCount int64 `json:"DiskCount"` + HostCount int64 `json:"HostCount"` + SnapshotCount int64 `json:"SnapshotCount"` + Used int64 `json:"Used"` + Wasted int64 `json:"Wasted"` + Account string `json:"account"` + AccountHealthStatus string `json:"account_health_status"` + AccountId string `json:"account_id"` + AccountReadOnly bool `json:"account_read_only"` + AccountStatus string `json:"account_status"` + ActualCapacityUsed int64 `json:"actual_capacity_used"` + Brand string `json:"brand"` + CanDelete bool `json:"can_delete"` + CanUpdate bool `json:"can_update"` + Capacity int64 `json:"capacity"` + CloudEnv string `json:"cloud_env"` + Cloudregion string `json:"cloudregion"` + CloudregionId string `json:"cloudregion_id"` + Cmtbound float64 `json:"cmtbound"` + CommitBound float64 `json:"commit_bound"` + CommitRate float64 `json:"commit_rate"` + CreatedAt *time.Time `json:"created_at"` + Deleted bool `json:"deleted"` + DeletedAt *time.Time `json:"deleted_at"` + Description string `json:"description"` + DomainId string `json:"domain_id"` + DomainSrc string `json:"domain_src"` + Enabled bool `json:"enabled"` + Environment string `json:"environment"` + ExternalId string `json:"external_id"` + FreeCapacity int64 `json:"free_capacity"` + Hosts []StorageHost `json:"hosts"` + Id string `json:"id"` + ImportedAt *time.Time `json:"imported_at"` + IsEmulated bool `json:"is_emulated"` + IsPublic bool `json:"is_public"` + IsSysDiskStore bool `json:"is_sys_disk_store"` + Manager string `json:"manager"` + ManagerDomain string `json:"manager_domain"` + ManagerDomainId string `json:"manager_domain_id"` + ManagerId string `json:"manager_id"` + ManagerProject string `json:"manager_project"` + ManagerProjectId string `json:"manager_project_id"` + MasterHost string `json:"master_host"` + MasterHostName string `json:"master_host_name"` + MediumType string `json:"medium_type"` + Metadata map[string]string `json:"metadata"` + Name string `json:"name"` + Progress float64 `json:"progress"` + ProjectDomain string `json:"project_domain"` + Provider string `json:"provider"` + PublicScope string `json:"public_scope"` + PublicSrc string `json:"public_src"` + RealTimeUsedCapacity int64 `json:"real_time_used_capacity"` + Region string `json:"region"` + RegionExtId string `json:"region_ext_id"` + RegionExternalId string `json:"region_external_id"` + RegionId string `json:"region_id"` + Reserved int64 `json:"reserved"` + Schedtags []SchedtagShortDescDetails `json:"schedtags"` + SharedDomains []SharedDomain `json:"shared_domains"` + SharedProjects []SharedProject `json:"shared_projects"` + Source string `json:"source"` + Status string `json:"status"` + StorageConf map[string]interface{} `json:"storage_conf"` + StorageType string `json:"storage_type"` + StoragecacheId string `json:"storagecache_id"` + UpdateVersion int64 `json:"update_version"` + UpdatedAt *time.Time `json:"updated_at"` + UsedCapacity int64 `json:"used_capacity"` + VirtualCapacity int64 `json:"virtual_capacity"` + WasteCapacity int64 `json:"waste_capacity"` + Zone string `json:"zone"` + ZoneExtId string `json:"zone_ext_id"` + ZoneId string `json:"zone_id"` +} + +type StorageListResponse struct { + Limit int64 `json:"limit"` + Offset int64 `json:"offset"` + Storages []StorageDetails `json:"storages"` + Total int64 `json:"total"` +} + +type ServerDetails struct { + Account string `json:"account"` + AccountHealthStatus string `json:"account_health_status"` + AccountId string `json:"account_id"` + AccountReadOnly bool `json:"account_read_only"` + AccountStatus string `json:"account_status"` + BackupGuestSync string `json:"backup_guest_sync"` + BackupGuestSyncStatus string `json:"backup_guest_sync_status"` + BackupHostId string `json:"backup_host_id"` + BackupHostName string `json:"backup_host_name"` + BackupHostStatus string `json:"backup_host_status"` + BillingCycle string `json:"billing_cycle"` + BillingType string `json:"billing_type"` + Bios string `json:"bios"` + BootOrder string `json:"boot_order"` + Brand string `json:"brand"` + CanDelete bool `json:"can_delete"` + CanRecycle bool `json:"can_recycle"` + CanUpdate bool `json:"can_update"` + Cdrom interface{} `json:"cdrom"` + CdromSupport bool `json:"cdrom_support"` + CloudEnv string `json:"cloud_env"` + Cloudregion string `json:"cloudregion"` + CloudregionId string `json:"cloudregion_id"` + Containers interface{} `json:"containers"` + CpuNumaPin map[string]interface{} `json:"cpu_numa_pin"` + CpuSockets int64 `json:"cpu_sockets"` + CreatedAt *time.Time `json:"created_at"` + DeleteFailReason interface{} `json:"delete_fail_reason"` + Deleted bool `json:"deleted"` + DeletedAt *time.Time `json:"deleted_at"` + Description string `json:"description"` + DisableDelete bool `json:"disable_delete"` + DiskSizeMb int64 `json:"disk"` + DiskCount int64 `json:"disk_count"` + Disks string `json:"disks"` + DisksInfo interface{} `json:"disks_info"` + DomainId string `json:"domain_id"` + Eip string `json:"eip"` + EipMode string `json:"eip_mode"` + EncryptAlg string `json:"encrypt_alg"` + EncryptKey string `json:"encrypt_key"` + EncryptKeyId string `json:"encrypt_key_id"` + EncryptKeyUser string `json:"encrypt_key_user"` + EncryptKeyUserDomain string `json:"encrypt_key_user_domain"` + EncryptKeyUserDomainId string `json:"encrypt_key_user_domain_id"` + EncryptKeyUserId string `json:"encrypt_key_user_id"` + Environment string `json:"environment"` + ExpiredAt *time.Time `json:"expired_at"` + ExternalId string `json:"external_id"` + ExtraCpuCount int64 `json:"extra_cpu_count"` + FlavorId string `json:"flavor_id"` + Floppy interface{} `json:"floppy"` + FloppySupport bool `json:"floppy_support"` + Freezed bool `json:"freezed"` + GpuCount string `json:"gpu_count"` + GpuModel string `json:"gpu_model"` + Host string `json:"host"` + HostAccessIp string `json:"host_access_ip"` + HostAccessMac string `json:"host_access_mac"` + HostBillingType string `json:"host_billing_type"` + HostEIP string `json:"host_eip"` + HostEnabled bool `json:"host_enabled"` + HostId string `json:"host_id"` + HostStatus string `json:"host_status"` + Hostname string `json:"hostname"` + Hypervisor string `json:"hypervisor"` + Id string `json:"id"` + ImportedAt *time.Time `json:"imported_at"` + Ips []string `json:"ips"` + IsBaremetal bool `json:"is_baremetal"` + IsDefer bool `json:"is_defer"` + IsEmulated bool `json:"is_emulated"` + IsMerge bool `json:"is_merge"` + IsMirror bool `json:"is_mirror"` + IsPublic bool `json:"is_public"` + IsSystem bool `json:"is_system"` + KeypairId string `json:"keypair_id"` + Manager string `json:"manager"` + ManagerDomain string `json:"manager_domain"` + ManagerDomainId string `json:"manager_domain_id"` + ManagerId string `json:"manager_id"` + ManagerProject string `json:"manager_project"` + ManagerProjectId string `json:"manager_project_id"` + MemoryPinned bool `json:"memory_pinned"` + Metadata map[string]string `json:"metadata"` + Mmemc interface{} `json:"mmemc"` + Name string `json:"name"` + NicType string `json:"nic_type"` + Nics interface{} `json:"nics"` + NSPSConfig map[string]interface{} `json:"nsps_config"` + OsArch string `json:"os_arch"` + OsFullName string `json:"os_full_name"` + OsName string `json:"os_name"` + OsType string `json:"os_type"` + PendingDeleted bool `json:"pending_deleted"` + PendingDeletedAt *time.Time `json:"pending_deleted_at"` + PowerStates string `json:"power_states"` + Progress float64 `json:"progress"` + Project string `json:"project"` + ProjectDomain string `json:"project_domain"` + ProjectId string `json:"project_id"` + ProjectMetadata map[string]string `json:"project_metadata"` + ProjectSrc string `json:"project_src"` + Provider string `json:"provider"` + PublicIp string `json:"public_ip"` + PublicScope string `json:"public_scope"` + PublicSrc string `json:"public_src"` + Rds bool `json:"rds"` + RecoveryMode string `json:"recovery_mode"` + ReorderMaster bool `json:"reorder_master"` + Schedtags []SchedtagShortDescDetails `json:"schedtags"` + SecurityGroup string `json:"security_group"` + SecurityGroupId string `json:"security_group_id"` + SecurityGroups interface{} `json:"security_groups"` + SharedDomains []SharedDomain `json:"shared_domains"` + SharedProjects []SharedProject `json:"shared_projects"` + ShutdownBehavior string `json:"shutdown_behavior"` + SourceOsDist string `json:"source_os_dist"` + Source string `json:"source"` + Status string `json:"status"` + StorageId string `json:"storage_id"` + StorageType string `json:"storage_type"` + SystemVmtypeName string `json:"system_vmtype_name"` + Tenant string `json:"tenant"` + TenantId string `json:"tenant_id"` + UpdateVersion int64 `json:"update_version"` + UpdatedAt *time.Time `json:"updated_at"` + UpgradeStatus string `json:"upgrade_status"` + UpdateFailReason interface{} `json:"update_fail_reason"` + UserData string `json:"user_data"` + VcpuCount int64 `json:"vcpu_count"` + VdiBrokerStuff map[string]interface{} `json:"vdi_broker_stuff"` + VdiConfig map[string]interface{} `json:"vdi_config"` + VditConfig map[string]interface{} `json:"vdit_config"` + VmemSize int64 `json:"vmem_size"` + VMEMSizeMb int64 `json:"vmem_size_mb"` + Vpc string `json:"vpc"` + VpcId string `json:"vpc_id"` + Zone string `json:"zone"` + ZoneId string `json:"zone_id"` +} + +type ServerListResponse struct { + Limit int64 `json:"limit"` + Offset int64 `json:"offset"` + Servers []ServerDetails `json:"servers"` + Total int64 `json:"total"` +} + +type ServerStartRequest struct { + AutoPrepaid bool + QemuVersion string +} + +type ServerStopRequest struct { + IsForce bool + StopCharging bool + TimeoutSecs int64 +} + +type ServerRestartRequest struct { + IsForce bool +} + +type ServerOperationResponse struct { + Id string `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + TaskId string `json:"task_id"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` + Operation string `json:"operation,omitempty"` +} + +type ServerResetPasswordRequest struct { + Password string + ResetPassword bool + AutoStart bool + Username string +} + +type ServerDeleteRequest struct { + OverridePendingDelete bool + Purge bool + DeleteSnapshots bool + DeleteEip bool + DeleteDisks bool +} + +type CreateServerRequest struct { + Name string + VcpuCount int64 + VmemSize int64 + ImageId string + DiskSize int64 + NetworkId string + ServerskuId string + Count int + Password string + AutoStart bool + BillingType string + Duration string + Description string + Hostname string + Hypervisor string + Metadata map[string]string + SecgroupId string + Secgroups []string + UserData string + KeypairId string + ProjectId string + ZoneId string + RegionId string + DisableDelete bool + BootOrder string + DataDisks []DiskConfig +} + +type DiskConfig struct { + ImageId string + Size int64 + DiskType string +} + +type ServerCreateResponseData struct { + Servers []ServerCreateInfo `json:"servers"` +} + +type ServerCreateInfo struct { + ID string `json:"id"` + Name string `json:"name"` + Status string `json:"status"` + TaskID string `json:"task_id"` +} + +type CreateServerResponse struct { + Status int `json:"status"` + Message string `json:"msg"` + Data ServerCreateResponseData `json:"data"` +} + +type MonitorResponse struct { + Status int `json:"status"` + Data MonitorResponseData `json:"data"` +} + +type MonitorResponseData struct { + Metrics []MetricData `json:"metrics"` +} + +type MetricData struct { + Metric string `json:"metric"` + Unit string `json:"unit"` + Values []MetricValue `json:"values"` +} + +type MetricValue struct { + Timestamp int64 `json:"timestamp"` + Value float64 `json:"value"` +} + +type ServerStatsResponse struct { + Status int `json:"status"` + Data ServerStatsData `json:"data"` +} + +type ServerStatsData struct { + CPUUsage float64 `json:"cpu_usage"` + MemUsage float64 `json:"mem_usage"` + DiskUsage float64 `json:"disk_usage"` + NetBpsRx int64 `json:"net_bps_rx"` + NetBpsTx int64 `json:"net_bps_tx"` + UpdatedAt string `json:"updated_at"` +} diff --git a/pkg/mcp-server/options/doc.go b/pkg/mcp-server/options/doc.go new file mode 100644 index 0000000000..48ef7763fe --- /dev/null +++ b/pkg/mcp-server/options/doc.go @@ -0,0 +1 @@ +package options // import "yunion.io/x/onecloud/pkg/mcp-server/options" diff --git a/pkg/mcp-server/options/options.go b/pkg/mcp-server/options/options.go new file mode 100644 index 0000000000..07b1fa81c7 --- /dev/null +++ b/pkg/mcp-server/options/options.go @@ -0,0 +1,37 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package options + +import ( + common_options "yunion.io/x/onecloud/pkg/cloudcommon/options" +) + +type MCPServerOptions struct { + common_options.CommonOptions + // 服务基础信息 + MCPServerName string `help:"MCP service name"` + MCPServerVersion string `help:"MCP service version"` + MCPServerDescription string `help:"MCP service description"` + + // 认证服务集成 + IdentityBaseURL string `help:"Authentication service entry URL"` + + // 连接超时配置 + Timeout int `help:"SDK connection timeout to cloudpods service (seconds)" default:"30"` +} + +var ( + Options MCPServerOptions +) diff --git a/pkg/mcp-server/registry/doc.go b/pkg/mcp-server/registry/doc.go new file mode 100644 index 0000000000..f284975509 --- /dev/null +++ b/pkg/mcp-server/registry/doc.go @@ -0,0 +1 @@ +package registry // import "yunion.io/x/onecloud/pkg/mcp-server/registry" diff --git a/pkg/mcp-server/registry/registry.go b/pkg/mcp-server/registry/registry.go new file mode 100644 index 0000000000..37dd2b0529 --- /dev/null +++ b/pkg/mcp-server/registry/registry.go @@ -0,0 +1,89 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package registry + +import ( + "fmt" + "sync" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" + + "yunion.io/x/log" +) + +type Registry struct { + mu sync.RWMutex + tools map[string]*ToolRegistration + mcpServer *server.MCPServer + initialized bool +} + +type ToolRegistration struct { + Tool mcp.Tool + Handler server.ToolHandlerFunc +} + +func NewRegistry() *Registry { + return &Registry{ + tools: make(map[string]*ToolRegistration), + } +} + +// Initialize 使用MCP服务器初始化注册中心 +func (r *Registry) Initialize(mcpServer *server.MCPServer) error { + r.mu.Lock() + defer r.mu.Unlock() + + if r.initialized { + return fmt.Errorf("Fail to init register ") + } + + r.mcpServer = mcpServer + + // 将所有已注册的工具添加到MCP服务器 + for _, registration := range r.tools { + r.mcpServer.AddTool(registration.Tool, registration.Handler) + } + + r.initialized = true + + return nil +} + +// RegisterTool 注册单个工具 +func (r *Registry) RegisterTool(toolName string, tool mcp.Tool, handler server.ToolHandlerFunc) error { + r.mu.Lock() + defer r.mu.Unlock() + + if _, exists := r.tools[toolName]; exists { + return fmt.Errorf("Tool already register: '%s' ", toolName) + } + + registration := &ToolRegistration{ + Tool: tool, + Handler: handler, + } + + r.tools[toolName] = registration + log.Infof("Tool register successfully: %s", toolName) + + // 如果MCP服务器已设置,立即注册到服务器 + if r.mcpServer != nil { + r.mcpServer.AddTool(tool, handler) + } + + return nil +} diff --git a/pkg/mcp-server/server/doc.go b/pkg/mcp-server/server/doc.go new file mode 100644 index 0000000000..a6ae5130c8 --- /dev/null +++ b/pkg/mcp-server/server/doc.go @@ -0,0 +1 @@ +package server // import "yunion.io/x/onecloud/pkg/mcp-server/server" diff --git a/pkg/mcp-server/server/server.go b/pkg/mcp-server/server/server.go new file mode 100644 index 0000000000..d149b21fae --- /dev/null +++ b/pkg/mcp-server/server/server.go @@ -0,0 +1,157 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package server + +import ( + "fmt" + + "github.com/mark3labs/mcp-go/server" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/options" + "yunion.io/x/onecloud/pkg/mcp-server/registry" + "yunion.io/x/onecloud/pkg/mcp-server/tools" +) + +// CloudpodsMCPServer 是 MCP 服务器的核心结构体,包含配置、日志、MCP 实例、注册中心和工具列表 +type CloudpodsMCPServer struct { + mcpServer *server.MCPServer + registry *registry.Registry + tools []tools.Tool +} + +// NewServer 创建一个新的 Cloudpods MCP 服务器实例,初始化 MCP 服务器和注册中心,并创建所有工具 +func NewServer() *CloudpodsMCPServer { + + // 创建mcp server对象 + mcpServer := server.NewMCPServer( + options.Options.MCPServerName, + options.Options.MCPServerVersion, + server.WithToolCapabilities(false), + server.WithRecovery(), + ) + + // 创建注册中心对象 + reg := registry.NewRegistry() + + var allTools []tools.Tool + + // 创建mcclient sdk的适配器对象 + adapter := adapters.NewCloudpodsAdapter() + + // 创建具体的工具函数对象 + // 用于查询资源的工具函数 + regionsTool := tools.NewCloudpodsRegionsTool(adapter) + vpcsTool := tools.NewCloudpodsVPCsTool(adapter) + networksTool := tools.NewCloudpodsNetworksTool(adapter) + imagesTool := tools.NewCloudpodsImagesTool(adapter) + skusTool := tools.NewCloudpodsServerSkusTool(adapter) + storagesTool := tools.NewCloudpodsStoragesTool(adapter) + serversTool := tools.NewCloudpodsServersTool(adapter) + + // 用于操作资源的工具函数 + serverStartTool := tools.NewCloudpodsServerStartTool(adapter) + serverStopTool := tools.NewCloudpodsServerStopTool(adapter) + serverRestartTool := tools.NewCloudpodsServerRestartTool(adapter) + serverResetPasswordTool := tools.NewCloudpodsServerResetPasswordTool(adapter) + serverDeleteTool := tools.NewCloudpodsServerDeleteTool(adapter) + serverCreateTool := tools.NewCloudpodsServerCreateTool(adapter) + serverMonitorTool := tools.NewCloudpodsServerMonitorTool(adapter) + serverStatsTool := tools.NewCloudpodsServerStatsTool(adapter) + + // 将所有的工具函数存储到一个切片中 + allTools = append( + allTools, + regionsTool, + vpcsTool, + networksTool, + imagesTool, + skusTool, + storagesTool, + serversTool, + + serverStartTool, + serverStopTool, + serverRestartTool, + serverResetPasswordTool, + serverDeleteTool, + serverCreateTool, + serverMonitorTool, + serverStatsTool, + ) + + return &CloudpodsMCPServer{ + mcpServer: mcpServer, + registry: reg, + tools: allTools, + } +} + +// Initialize 初始化注册中心和所有工具 +func (s *CloudpodsMCPServer) Initialize() error { + + // 初始化工具注册中心 + if err := s.registry.Initialize(s.mcpServer); err != nil { + return fmt.Errorf("初始化工具注册中心失败: %w", err) + } + + // 注册内置工具 + if err := s.registerAllTools(); err != nil { + return fmt.Errorf("注册内置工具失败: %w", err) + } + + return nil +} + +// registerAllTools 将所有工具注册到注册中心 +func (s *CloudpodsMCPServer) registerAllTools() error { + for _, tool := range s.tools { + // 注册距离查询工具 + if err := s.registry.RegisterTool( + tool.GetName(), + tool.GetTool(), + tool.Handle, + ); err != nil { + return fmt.Errorf("注册工具失败: %w", err) + } + } + + log.Infof("All tools register completed") + return nil +} + +// Start 以sse模式启动 mcp 服务 +func (s *CloudpodsMCPServer) Start() error { + + if err := server.NewSSEServer(s.mcpServer).Start(fmt.Sprintf("%s:%d", options.Options.Address, options.Options.Port)); err != nil { + return err + } + log.Infof("Start mcp server successfully") + + return nil +} + +// StartStdio 以stdio模式启动 mcp 服务 +func (s *CloudpodsMCPServer) StartStdio() error { + + err := server.ServeStdio(s.mcpServer) + if err != nil { + return err + } + log.Infof("Start mcp server successfully") + return nil +} diff --git a/pkg/mcp-server/service/doc.go b/pkg/mcp-server/service/doc.go new file mode 100644 index 0000000000..b15c1d82e2 --- /dev/null +++ b/pkg/mcp-server/service/doc.go @@ -0,0 +1 @@ +package service // import "yunion.io/x/onecloud/pkg/mcp-server/service" diff --git a/pkg/mcp-server/service/service.go b/pkg/mcp-server/service/service.go new file mode 100644 index 0000000000..1905556038 --- /dev/null +++ b/pkg/mcp-server/service/service.go @@ -0,0 +1,44 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "os" + + "yunion.io/x/log" + + common_options "yunion.io/x/onecloud/pkg/cloudcommon/options" + "yunion.io/x/onecloud/pkg/mcp-server/options" + "yunion.io/x/onecloud/pkg/mcp-server/server" +) + +func StartService() { + + opts := &options.Options + common_options.ParseOptions(opts, os.Args, "mcpserver.conf", "mcpserver") + + // 创建服务器 + srv := server.NewServer() + + // 初始化服务器 + if err := srv.Initialize(); err != nil { + log.Fatalf("Fail to init mcp server: %s", err) + } + + // 启动服务器 + if err := srv.Start(); err != nil { + log.Fatalf("Fail to start mcp server: %s", err) + } +} diff --git a/pkg/mcp-server/tools/cloudpods_images_tool.go b/pkg/mcp-server/tools/cloudpods_images_tool.go new file mode 100644 index 0000000000..32a7dc7457 --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_images_tool.go @@ -0,0 +1,240 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsImagesTool 是一个用于查询 Cloudpods 镜像列表的工具 +// 它封装了 Cloudpods 适配器和日志记录器 +type CloudpodsImagesTool struct { + // adapter 用于与 Cloudpods API 进行交互 + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsImagesTool 创建一个新的 CloudpodsImagesTool 实例 +// 参数: +// - adapter: Cloudpods 适配器实例,用于与 Cloudpods API 交互 +// +// 返回值: +// - *CloudpodsImagesTool: 新创建的 CloudpodsImagesTool 实例 +func NewCloudpodsImagesTool(adapter *adapters.CloudpodsAdapter) *CloudpodsImagesTool { + return &CloudpodsImagesTool{ + adapter: adapter, + } +} + +// GetTool 定义并返回 Cloudpods 镜像列表查询工具的元数据 +// 该工具允许用户查询 Cloudpods 中的磁盘镜像列表,并支持多种查询参数 +// 返回值: +// - mcp.Tool: 定义了工具名称、描述和参数的工具对象 +func (c *CloudpodsImagesTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_list_images", + mcp.WithDescription("查询Cloudpods磁盘镜像列表,获取系统镜像信息"), + mcp.WithString("limit", mcp.Description("返回结果数量限制,默认为20")), + mcp.WithString("offset", mcp.Description("返回结果偏移量,默认为0")), + mcp.WithString("search", mcp.Description("搜索关键词,可以按镜像名称搜索")), + mcp.WithString("os_types", mcp.Description("操作系统类型,多个用逗号分隔,如:Linux,Windows,FreeBSD")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理 Cloudpods 镜像列表查询请求 +// 该方法解析请求参数,调用适配器查询镜像列表,并格式化返回结果 +// 参数: +// - ctx: 上下文对象,用于控制请求生命周期 +// - req: 工具调用请求对象,包含查询参数 +// +// 返回值: +// - *mcp.CallToolResult: 格式化后的镜像列表查询结果 +// - error: 如果查询过程中发生错误,则返回相应的错误信息 +func (c *CloudpodsImagesTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 设置默认的查询结果数量限制为20 + limit := 20 + // 如果请求中包含limit参数且为有效正整数,则使用该值 + if limitStr := req.GetString("limit", ""); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + // 设置默认的查询偏移量为0 + offset := 0 + // 如果请求中包含offset参数且为有效非负整数,则使用该值 + if offsetStr := req.GetString("offset", ""); offsetStr != "" { + if parsedOffset, err := strconv.Atoi(offsetStr); err == nil && parsedOffset >= 0 { + offset = parsedOffset + } + } + + // 获取搜索关键词参数 + search := req.GetString("search", "") + + // 解析操作系统类型参数,支持多个类型用逗号分隔 + var osTypes []string + if osTypesStr := req.GetString("os_types", ""); osTypesStr != "" { + osTypes = strings.Split(osTypesStr, ",") + for i, osType := range osTypes { + osTypes[i] = strings.TrimSpace(osType) + } + } + + // 获取访问凭证 + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器查询镜像列表 + imagesResponse, err := c.adapter.ListImages(limit, offset, search, osTypes, ak, sk) + if err != nil { + log.Errorf("Fail to query image: %s", err) + return nil, fmt.Errorf("fail to query image: %w", err) + } + + // 格式化查询结果 + formattedResult := c.formatImagesResult(imagesResponse, limit, offset, search, osTypes) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + + // 返回格式化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回工具的名称标识符 +// 返回值: +// - string: 工具名称,用于唯一标识该工具 +func (c *CloudpodsImagesTool) GetName() string { + return "cloudpods_list_images" +} + +// formatImagesResult 格式化镜像列表查询结果 +// 该方法将从适配器获取的原始镜像数据转换为结构化的响应格式,包含查询信息、镜像详情和摘要信息 +// 参数: +// - response: 从适配器获取的原始镜像列表响应数据 +// - limit: 查询结果数量限制 +// - offset: 查询偏移量 +// - search: 搜索关键词 +// - osTypes: 操作系统类型过滤条件 +// +// 返回值: +// - map[string]interface{}: 格式化后的镜像列表数据,包含查询信息、镜像详情和摘要 +func (c *CloudpodsImagesTool) formatImagesResult(response *models.ImageListResponse, limit, offset int, search string, osTypes []string) map[string]interface{} { + // 初始化格式化结果结构 + formatted := map[string]interface{}{ + // 查询信息部分,包含查询参数和结果统计 + "query_info": map[string]interface{}{ + "limit": limit, + "offset": offset, + "search": search, + "os_types": osTypes, + "total": response.Total, + "count": len(response.Images), + }, + // 镜像列表部分,初始化为空数组 + "images": make([]map[string]interface{}, 0, len(response.Images)), + } + + // 遍历原始镜像数据,提取每个镜像的详细信息 + for _, image := range response.Images { + // 构造单个镜像的详细信息 + imageInfo := map[string]interface{}{ + "id": image.Id, + "name": image.Name, + "description": image.Description, + "status": image.Status, + "disk_format": image.DiskFormat, + "size": image.Size, + "checksum": image.Checksum, + "oss_checksum": image.OssChecksum, + "fast_hash": image.FastHash, + "location": image.Location, + "os_arch": image.OsArch, + "min_disk": image.MinDisk, + "min_ram": image.MinRam, + "is_data": image.IsData, + "is_guest_image": image.IsGuestImage, + "is_public": image.IsPublic, + "is_standard": image.IsStandard, + "is_system": image.IsSystem, + "is_emulated": image.IsEmulated, + "protected": image.Protected, + "disable_delete": image.DisableDelete, + "freezed": image.Freezed, + "pending_deleted": image.PendingDeleted, + "pending_deleted_at": image.PendingDeletedAt, + "auto_delete_at": image.AutoDeleteAt, + "encrypt_alg": image.EncryptAlg, + "encrypt_key": image.EncryptKey, + "encrypt_key_id": image.EncryptKeyId, + "encrypt_key_user": image.EncryptKeyUser, + "encrypt_key_user_domain": image.EncryptKeyUserDomain, + "encrypt_key_user_domain_id": image.EncryptKeyUserDomainId, + "encrypt_key_user_id": image.EncryptKeyUserId, + "encrypt_status": image.EncryptStatus, + "owner": image.Owner, + "project": image.Project, + "project_id": image.ProjectId, + "project_domain": image.ProjectDomain, + "project_metadata": image.ProjectMetadata, + "project_src": image.ProjectSrc, + "tenant": image.Tenant, + "tenant_id": image.TenantId, + "domain_id": image.DomainId, + "public_scope": image.PublicScope, + "public_src": image.PublicSrc, + "shared_domains": image.SharedDomains, + "shared_projects": image.SharedProjects, + "properties": image.Properties, + "metadata": image.Metadata, + "progress": image.Progress, + "can_delete": image.CanDelete, + "can_update": image.CanUpdate, + "update_version": image.UpdateVersion, + "created_at": image.CreatedAt, + "updated_at": image.UpdatedAt, + } + // 将镜像信息添加到结果数组中 + formatted["images"] = append(formatted["images"].([]map[string]interface{}), imageInfo) + } + + // 构造结果摘要信息 + formatted["summary"] = map[string]interface{}{ + "total_images": response.Total, + "returned_count": len(response.Images), + "has_more": response.Total > int64(offset+len(response.Images)), + "next_offset": offset + len(response.Images), + } + + // 返回格式化后的完整结果 + return formatted +} diff --git a/pkg/mcp-server/tools/cloudpods_networks_tool.go b/pkg/mcp-server/tools/cloudpods_networks_tool.go new file mode 100644 index 0000000000..83f0f8950d --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_networks_tool.go @@ -0,0 +1,265 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsNetworksTool 是一个用于查询 Cloudpods 网络列表的工具 +type CloudpodsNetworksTool struct { + // adapter 用于与 Cloudpods API 进行交互 + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsNetworksTool 创建一个新的 CloudpodsNetworksTool 实例 +// adapter: 用于与 Cloudpods API 进行交互的适配器 +// 返回值: 指向新创建的 CloudpodsNetworksTool 实例的指针 +func NewCloudpodsNetworksTool(adapter *adapters.CloudpodsAdapter) *CloudpodsNetworksTool { + return &CloudpodsNetworksTool{ + adapter: adapter, + } +} + +// GetTool 定义并返回网络列表查询工具的元数据 +// 该工具用于查询Cloudpods中的IP子网列表,获取网络配置信息 +// 支持的参数包括: +// - limit: 返回结果数量限制,默认为20 +// - offset: 返回结果偏移量,默认为0 +// - search: 搜索关键词,可以按网络名称搜索 +// - vpc_id: 过滤指定VPC的网络资源 +// - ak: 用户登录cloudpods后获取的access key +// - sk: 用户登录cloudpods后获取的secret key +func (c *CloudpodsNetworksTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_list_networks", + mcp.WithDescription("查询Cloudpods IP子网列表,获取网络配置信息"), + mcp.WithString("limit", mcp.Description("返回结果数量限制,默认为20")), + mcp.WithString("offset", mcp.Description("返回结果偏移量,默认为0")), + mcp.WithString("search", mcp.Description("搜索关键词,可以按网络名称搜索")), + mcp.WithString("vpc_id", mcp.Description("过滤指定VPC的网络资源")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理网络列表查询请求 +// ctx: 控制请求生命周期的上下文 +// req: 包含查询参数的请求对象 +// 返回值: 包含查询结果的工具结果对象或错误信息 +func (c *CloudpodsNetworksTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 设置默认查询限制为20 + limit := 20 + if limitStr := req.GetString("limit", ""); limitStr != "" { + // 解析limit参数,如果解析成功且大于0,则使用解析后的值 + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + // 设置默认偏移量为0 + offset := 0 + if offsetStr := req.GetString("offset", ""); offsetStr != "" { + // 解析offset参数,如果解析成功且大于等于0,则使用解析后的值 + if parsedOffset, err := strconv.Atoi(offsetStr); err == nil && parsedOffset >= 0 { + offset = parsedOffset + } + } + + // 获取搜索关键词和VPC ID参数 + search := req.GetString("search", "") + vpcId := req.GetString("vpc_id", "") + + // 获取访问凭证 + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器获取网络列表 + networksResponse, err := c.adapter.ListNetworks(limit, offset, search, vpcId, ak, sk) + if err != nil { + log.Errorf("Fail to query network: %s", err) + return nil, fmt.Errorf("fail to query network: %w", err) + } + + // 格式化查询结果 + formattedResult := c.formatNetworksResult(networksResponse, limit, offset, search, vpcId) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + + // 返回格式化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回工具的名称标识符 +// 返回值: 工具名称字符串,用于唯一标识该工具 +func (c *CloudpodsNetworksTool) GetName() string { + return "cloudpods_list_networks" +} + +// formatNetworksResult 格式化网络列表查询结果 +// response: 从适配器获取的原始网络数据 +// limit: 查询限制数量 +// offset: 查询偏移量 +// search: 搜索关键词 +// vpcId: VPC ID过滤条件 +// 返回值: 格式化后的网络列表数据,包含查询信息、网络列表和摘要信息 +func (c *CloudpodsNetworksTool) formatNetworksResult(response *models.NetworkListResponse, limit, offset int, search, vpcId string) map[string]interface{} { + // 初始化结果结构,包含查询信息和网络列表 + formatted := map[string]interface{}{ + "query_info": map[string]interface{}{ + "limit": limit, + "offset": offset, + "search": search, + "vpc_id": vpcId, + "total": response.Total, + "count": len(response.Networks), + }, + "networks": make([]map[string]interface{}, 0, len(response.Networks)), + } + + // 遍历原始网络数据,构造每个网络的详细信息 + for _, network := range response.Networks { + // 构造单个网络信息 + networkInfo := map[string]interface{}{ + "id": network.Id, + "name": network.Name, + "description": network.Description, + "status": network.Status, + "guest_ip_start": network.GuestIpStart, + "guest_ip_end": network.GuestIpEnd, + "guest_ip_mask": network.GuestIpMask, + "guest_gateway": network.GuestGateway, + "guest_dns": network.GuestDns, + "guest_dhcp": network.GuestDhcp, + "guest_ntp": network.GuestNtp, + "guest_domain": network.GuestDomain, + "guest_ip6_start": network.GuestIp6Start, + "guest_ip6_end": network.GuestIp6End, + "guest_ip6_mask": network.GuestIp6Mask, + "guest_gateway6": network.GuestGateway6, + "guest_dns6": network.GuestDns6, + "guest_domain6": network.GuestDomain6, + "vpc": network.Vpc, + "vpc_id": network.VpcId, + "vpc_ext_id": network.VpcExtId, + "wire": network.Wire, + "wire_id": network.WireId, + "zone": network.Zone, + "zone_id": network.ZoneId, + "cloudregion": network.Cloudregion, + "cloudregion_id": network.CloudregionId, + "region": network.Region, + "region_id": network.RegionId, + "provider": network.Provider, + "brand": network.Brand, + "cloud_env": network.CloudEnv, + "environment": network.Environment, + "external_id": network.ExternalId, + "account": network.Account, + "account_id": network.AccountId, + "account_status": network.AccountStatus, + "account_health_status": network.AccountHealthStatus, + "manager": network.Manager, + "manager_id": network.ManagerId, + "manager_domain": network.ManagerDomain, + "manager_domain_id": network.ManagerDomainId, + "manager_project": network.ManagerProject, + "manager_project_id": network.ManagerProjectId, + "server_type": network.ServerType, + "alloc_policy": network.AllocPolicy, + "vlan_id": network.VlanId, + "bgp_type": network.BgpType, + "is_auto_alloc": network.IsAutoAlloc, + "is_classic": network.IsClassic, + "is_default_vpc": network.IsDefaultVpc, + "is_public": network.IsPublic, + "is_system": network.IsSystem, + "is_emulated": network.IsEmulated, + "exit": network.Exit, + "freezed": network.Freezed, + "pending_deleted": network.PendingDeleted, + "pending_deleted_at": network.PendingDeletedAt, + "ports": network.Ports, + "ports_used": network.PortsUsed, + "ports6_used": network.Ports6Used, + "total": network.Total, + "total6": network.Total6, + "vnics": network.Vnics, + "vnics4": network.Vnics4, + "vnics6": network.Vnics6, + "bm_vnics": network.BmVnics, + "bm_reused_vnics": network.BmReusedVnics, + "eip_vnics": network.EipVnics, + "group_vnics": network.GroupVnics, + "lb_vnics": network.LbVnics, + "nat_vnics": network.NatVnics, + "networkinterface_vnics": network.NetworkinterfaceVnics, + "rds_vnics": network.RdsVnics, + "reserve_vnics4": network.ReserveVnics4, + "reserve_vnics6": network.ReserveVnics6, + "routes": network.Routes, + "schedtags": network.Schedtags, + "additional_wires": network.AdditionalWires, + "shared_domains": network.SharedDomains, + "shared_projects": network.SharedProjects, + "project": network.Project, + "project_id": network.ProjectId, + "project_domain": network.ProjectDomain, + "project_metadata": network.ProjectMetadata, + "project_src": network.ProjectSrc, + "tenant": network.Tenant, + "tenant_id": network.TenantId, + "domain_id": network.DomainId, + "public_scope": network.PublicScope, + "public_src": network.PublicSrc, + "source": network.Source, + "progress": network.Progress, + "can_delete": network.CanDelete, + "can_update": network.CanUpdate, + "metadata": network.Metadata, + "created_at": network.CreatedAt, + "updated_at": network.UpdatedAt, + "imported_at": network.ImportedAt, + } + // 将网络信息添加到结果数组中 + formatted["networks"] = append(formatted["networks"].([]map[string]interface{}), networkInfo) + } + + // 构造摘要信息 + formatted["summary"] = map[string]interface{}{ + "total_networks": response.Total, + "returned_count": len(response.Networks), + "has_more": response.Total > int64(offset+len(response.Networks)), + "next_offset": offset + len(response.Networks), + } + + // 返回格式化后的结果 + return formatted +} diff --git a/pkg/mcp-server/tools/cloudpods_regions_tool.go b/pkg/mcp-server/tools/cloudpods_regions_tool.go new file mode 100644 index 0000000000..89e65eae9e --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_regions_tool.go @@ -0,0 +1,192 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsRegionsTool 是用于查询 Cloudpods 区域列表的工具 +type CloudpodsRegionsTool struct { + // adapter 用于与 Cloudpods API 进行交互 + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsRegionsTool 创建一个新的 Cloudpods 区域查询工具 +// adapter: 用于与 Cloudpods API 进行交互的适配器 +// 返回值: 指向新创建的 CloudpodsRegionsTool 实例的指针 +func NewCloudpodsRegionsTool(adapter *adapters.CloudpodsAdapter) *CloudpodsRegionsTool { + return &CloudpodsRegionsTool{ + adapter: adapter, + } +} + +// GetTool 返回 MCP 工具定义,用于查询 Cloudpods 区域列表 +// 该工具用于查询Cloudpods中的区域列表,获取所有可用的云区域信息 +// 支持的参数包括: +// - limit: 返回结果数量限制,默认为50 +// - offset: 返回结果偏移量,默认为0 +// - search: 搜索关键词,可以按区域名称搜索 +// - provider: 云平台提供商,例如:aws、azure、aliyun等 +// - ak: 用户登录cloudpods后获取的access key +// - sk: 用户登录cloudpods后获取的secret key +func (c *CloudpodsRegionsTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_list_regions", + mcp.WithDescription("查询Cloudpods区域列表,获取所有可用的云区域信息"), + mcp.WithString("limit", mcp.Description("返回结果数量限制,默认为50")), + mcp.WithString("offset", mcp.Description("返回结果偏移量,默认为0")), + mcp.WithString("search", mcp.Description("搜索关键词,可以按区域名称搜索")), + mcp.WithString("provider", mcp.Description("云平台提供商,例如:aws、azure、aliyun等")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理查询 Cloudpods 区域列表的请求 +// ctx: 控制请求生命周期的上下文 +// req: 包含查询参数的请求对象 +// 返回值: 包含查询结果的工具结果对象或错误信息 +func (c *CloudpodsRegionsTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 设置默认查询限制为50 + limit := 50 + if limitStr := req.GetString("limit", ""); limitStr != "" { + // 解析limit参数,如果解析成功且大于0,则使用解析后的值 + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + // 设置默认偏移量为0 + offset := 0 + if offsetStr := req.GetString("offset", ""); offsetStr != "" { + // 解析offset参数,如果解析成功且大于等于0,则使用解析后的值 + if parsedOffset, err := strconv.Atoi(offsetStr); err == nil && parsedOffset >= 0 { + offset = parsedOffset + } + } + + // 获取搜索关键词和提供商参数 + search := req.GetString("search", "") + provider := req.GetString("provider", "") + + // 获取访问凭证 + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器获取区域列表 + regionsResponse, err := c.adapter.ListCloudRegions(ctx, limit, offset, search, provider, ak, sk) + if err != nil { + log.Errorf("Fail to query region: %s", err) + return nil, fmt.Errorf("fail to query region: %w", err) + } + + // 格式化查询结果 + formattedResult := c.formatRegionsResult(regionsResponse, limit, offset, search, provider) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + // 返回格式化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// formatRegionsResult 格式化区域列表查询结果 +// response: 从适配器获取的原始区域数据 +// limit: 查询限制数量 +// offset: 查询偏移量 +// search: 搜索关键词 +// provider: 云平台提供商 +// 返回值: 格式化后的区域列表数据,包含查询信息、区域列表和摘要信息 +func (c *CloudpodsRegionsTool) formatRegionsResult(response *models.CloudregionListResponse, limit, offset int, search, provider string) map[string]interface{} { + // 初始化结果结构,包含查询信息和区域列表 + formatted := map[string]interface{}{ + "query_info": map[string]interface{}{ + "limit": limit, + "offset": offset, + "search": search, + "provider": provider, + "total": response.Total, + "count": len(response.Cloudregions), + }, + "cloudregions": make([]map[string]interface{}, 0, len(response.Cloudregions)), + } + + // 遍历原始区域数据,构造每个区域的详细信息 + for _, region := range response.Cloudregions { + // 构造单个区域信息 + regionInfo := map[string]interface{}{ + "id": region.Id, + "name": region.Name, + "description": region.Description, + "provider": region.Provider, + "cloud_env": region.CloudEnv, + "environment": region.Environment, + "city": region.City, + "country_code": region.CountryCode, + "latitude": region.Latitude, + "longitude": region.Longitude, + "status": region.Status, + "enabled": region.Enabled, + "external_id": region.ExternalId, + "guest_count": region.GuestCount, + "guest_increment_count": region.GuestIncrementCount, + "network_count": region.NetworkCount, + "vpc_count": region.VpcCount, + "zone_count": region.ZoneCount, + "progress": region.Progress, + "source": region.Source, + "can_delete": region.CanDelete, + "can_update": region.CanUpdate, + "is_emulated": region.IsEmulated, + "metadata": region.Metadata, + "created_at": region.CreatedAt, + "updated_at": region.UpdatedAt, + "imported_at": region.ImportedAt, + } + // 将区域信息添加到结果数组中 + formatted["cloudregions"] = append(formatted["cloudregions"].([]map[string]interface{}), regionInfo) + } + + // 构造摘要信息 + formatted["summary"] = map[string]interface{}{ + "total_cloudregions": response.Total, + "returned_count": len(response.Cloudregions), + "has_more": response.Total > int64(offset+len(response.Cloudregions)), + "next_offset": offset + len(response.Cloudregions), + } + + // 返回格式化后的结果 + return formatted +} + +// GetName 返回工具名称 +// 返回值: 工具名称字符串,用于唯一标识该工具 +func (c *CloudpodsRegionsTool) GetName() string { + return "cloudpods_list_regions" +} diff --git a/pkg/mcp-server/tools/cloudpods_server_create_tool.go b/pkg/mcp-server/tools/cloudpods_server_create_tool.go new file mode 100644 index 0000000000..129e85047c --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_server_create_tool.go @@ -0,0 +1,350 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsServerCreateTool 用于创建Cloudpods虚拟机实例 +type CloudpodsServerCreateTool struct { + // adapter 用于与 Cloudpods API 进行交互 + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServerCreateTool 创建一个新的CloudpodsServerCreateTool实例 +// adapter: 用于与Cloudpods API交互的适配器 +// 返回值: CloudpodsServerCreateTool实例指针 +func NewCloudpodsServerCreateTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServerCreateTool { + return &CloudpodsServerCreateTool{ + adapter: adapter, + } +} + +// GetTool 定义并返回创建虚拟机工具的元数据 +// 该工具用于创建Cloudpods虚拟机实例,支持指定各种配置参数 +// name: 虚拟机名称 (必填) +// vcpu_count: CPU核心数 (必填) +// vmem_size: 内存大小(MB) (必填) +// image_id: 镜像ID (必填) +// disk_size: 系统盘大小(GB),不指定则使用镜像默认大小 +// network_id: 网络ID (必填) +// serversku_id: 套餐ID,指定后将忽略vcpu_count和vmem_size参数 +// password: 虚拟机密码,长度8-30个字符 +// count: 创建数量,默认为1 +// auto_start: 是否自动启动,默认为true +// billing_type: 计费类型,例如:postpaid、prepaid +// duration: 包年包月时长,例如:1M、1Y +// description: 描述信息 +// hostname: 主机名 +// hypervisor: 虚拟化技术,如kvm, esxi等,默认为kvm +// metadata: 标签列表,格式为JSON字符串,例如:{"key1":"value1","key2":"value2"} +// secgroup_id: 安全组ID +// secgroups: 安全组ID列表,多个ID用逗号分隔 +// user_data: 用户自定义启动脚本 +// keypair_id: 秘钥对ID +// project_id: 项目ID +// zone_id: 可用区ID +// region_id: 区域ID +// disable_delete: 是否开启删除保护,默认为true +// boot_order: 启动顺序,如cdn +// data_disks: 数据盘配置,格式为JSON字符串数组,例如:[{"size":100,"disk_type":"data"}] +// ak: 用户登录cloudpods后获取的access key +// sk: 用户登录cloudpods后获取的secret key +func (c *CloudpodsServerCreateTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_create_server", + mcp.WithDescription("创建Cloudpods虚拟机实例"), + mcp.WithString("name", mcp.Required(), mcp.Description("虚拟机名称")), + mcp.WithString("vcpu_count", mcp.Required(), mcp.Description("CPU核心数")), + mcp.WithString("vmem_size", mcp.Required(), mcp.Description("内存大小(MB)")), + mcp.WithString("image_id", mcp.Required(), mcp.Description("镜像ID")), + mcp.WithString("disk_size", mcp.Description("系统盘大小(GB),不指定则使用镜像默认大小")), + mcp.WithString("network_id", mcp.Required(), mcp.Description("网络ID")), + mcp.WithString("serversku_id", mcp.Description("套餐ID,指定后将忽略vcpu_count和vmem_size参数")), + mcp.WithString("password", mcp.Description("虚拟机密码,长度8-30个字符")), + mcp.WithString("count", mcp.Description("创建数量,默认为1")), + mcp.WithString("auto_start", mcp.Description("是否自动启动,默认为true")), + mcp.WithString("billing_type", mcp.Description("计费类型,例如:postpaid、prepaid")), + mcp.WithString("duration", mcp.Description("包年包月时长,例如:1M、1Y")), + mcp.WithString("description", mcp.Description("描述信息")), + mcp.WithString("hostname", mcp.Description("主机名")), + mcp.WithString("hypervisor", mcp.Description("虚拟化技术,如kvm, esxi等,默认为kvm")), + mcp.WithString("metadata", mcp.Description("标签列表,格式为JSON字符串,例如:{\"key1\":\"value1\",\"key2\":\"value2\"}")), + mcp.WithString("secgroup_id", mcp.Description("安全组ID")), + mcp.WithString("secgroups", mcp.Description("安全组ID列表,多个ID用逗号分隔")), + mcp.WithString("user_data", mcp.Description("用户自定义启动脚本")), + mcp.WithString("keypair_id", mcp.Description("秘钥对ID")), + mcp.WithString("project_id", mcp.Description("项目ID")), + mcp.WithString("zone_id", mcp.Description("可用区ID")), + mcp.WithString("region_id", mcp.Description("区域ID")), + mcp.WithString("disable_delete", mcp.Description("是否开启删除保护,默认为true")), + mcp.WithString("boot_order", mcp.Description("启动顺序,如cdn")), + mcp.WithString("data_disks", mcp.Description("数据盘配置,格式为JSON字符串数组,例如:[{\"size\":100,\"disk_type\":\"data\"}]")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理创建虚拟机的请求 +// ctx: 上下文,用于控制请求的生命周期 +// req: 包含创建虚拟机所需参数的请求对象 +// 返回值: 包含创建结果的工具结果对象或错误信息 +func (c *CloudpodsServerCreateTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 获取必填参数:虚拟机名称 + name, err := req.RequireString("name") + if err != nil { + return nil, err + } + + // 获取必填参数:镜像ID + imageID, err := req.RequireString("image_id") + if err != nil { + return nil, err + } + + // 获取必填参数:网络ID + networkID, err := req.RequireString("network_id") + if err != nil { + return nil, err + } + + // 获取必填参数:CPU核心数并转换为整数 + vcpuCountStr, err := req.RequireString("vcpu_count") + if err != nil { + return nil, err + } + vcpuCount, err := strconv.ParseInt(vcpuCountStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("无效的CPU核心数: %s", vcpuCountStr) + } + + // 获取必填参数:内存大小并转换为整数 + vmemSizeStr, err := req.RequireString("vmem_size") + if err != nil { + return nil, err + } + vmemSize, err := strconv.ParseInt(vmemSizeStr, 10, 64) + if err != nil { + return nil, fmt.Errorf("无效的内存大小: %s", vmemSizeStr) + } + + // 获取可选参数:套餐ID + serverSkuID := req.GetString("serversku_id", "") + + // 获取可选参数:磁盘大小,如果指定则转换为整数 + diskSize := int64(0) + if diskSizeStr := req.GetString("disk_size", ""); diskSizeStr != "" { + if parsedSize, err := strconv.ParseInt(diskSizeStr, 10, 64); err == nil && parsedSize > 0 { + diskSize = parsedSize + } + } + + // 获取可选参数:虚拟机密码,并验证长度 + password := req.GetString("password", "") + if password != "" && (len(password) < 8 || len(password) > 30) { + return nil, fmt.Errorf("密码长度必须在8-30个字符之间") + } + + // 获取可选参数:创建数量,默认为1 + count := 1 + if countStr := req.GetString("count", "1"); countStr != "1" { + if parsedCount, err := strconv.Atoi(countStr); err == nil && parsedCount > 0 { + count = parsedCount + } + } + + // 获取可选参数:是否自动启动,默认为true + autoStart := true + if autoStartStr := req.GetString("auto_start", "true"); autoStartStr == "false" { + autoStart = false + } + + // 获取可选参数:是否开启删除保护,默认为true + disableDelete := true + if disableDeleteStr := req.GetString("disable_delete", "true"); disableDeleteStr == "false" { + disableDelete = false + } + + // 获取其他可选参数 + billingType := req.GetString("billing_type", "") + duration := req.GetString("duration", "") + description := req.GetString("description", "") + hostname := req.GetString("hostname", "") + hypervisor := req.GetString("hypervisor", "") + secgroupID := req.GetString("secgroup_id", "") + userData := req.GetString("user_data", "") + keypairID := req.GetString("keypair_id", "") + projectID := req.GetString("project_id", "") + zoneID := req.GetString("zone_id", "") + regionID := req.GetString("region_id", "") + bootOrder := req.GetString("boot_order", "") + + // 获取安全组ID列表,并按逗号分割 + var secgroups []string + if secgroupsStr := req.GetString("secgroups", ""); secgroupsStr != "" { + secgroups = strings.Split(secgroupsStr, ",") + } + + // 解析元数据JSON字符串 + metadata := make(map[string]string) + if metadataStr := req.GetString("metadata", ""); metadataStr != "" { + if err := json.Unmarshal([]byte(metadataStr), &metadata); err != nil { + return nil, fmt.Errorf("无效的元数据JSON格式: %w", err) + } + } + + // 解析数据盘配置JSON数组 + var dataDisks []models.DiskConfig + if dataDisksStr := req.GetString("data_disks", ""); dataDisksStr != "" { + if err := json.Unmarshal([]byte(dataDisksStr), &dataDisks); err != nil { + return nil, fmt.Errorf("无效的数据盘配置JSON格式: %w", err) + } + } + + // 构造创建虚拟机的请求对象 + createRequest := models.CreateServerRequest{ + Name: name, + VcpuCount: vcpuCount, + VmemSize: vmemSize, + ImageId: imageID, + DiskSize: diskSize, + NetworkId: networkID, + ServerskuId: serverSkuID, + Count: count, + Password: password, + AutoStart: autoStart, + BillingType: billingType, + Duration: duration, + Description: description, + Hostname: hostname, + Hypervisor: hypervisor, + Metadata: metadata, + SecgroupId: secgroupID, + Secgroups: secgroups, + UserData: userData, + KeypairId: keypairID, + ProjectId: projectID, + ZoneId: zoneID, + RegionId: regionID, + DisableDelete: disableDelete, + BootOrder: bootOrder, + DataDisks: dataDisks, + } + + // 获取访问凭证 + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器创建虚拟机 + response, err := c.adapter.CreateServer(ctx, createRequest, ak, sk) + if err != nil { + log.Errorf("Fail to create server: %s", err) + return nil, fmt.Errorf("fail to create server: %w", err) + } + + // 格式化创建结果 + formattedResult := c.formatCreateResult(response, &createRequest) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + + // 返回格式化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回工具的名称标识符 +// 返回值: 工具名称字符串,用于唯一标识该工具 +func (c *CloudpodsServerCreateTool) GetName() string { + return "cloudpods_create_server" +} + +// formatCreateResult 格式化创建虚拟机的响应结果 +// response: 原始的创建虚拟机响应数据 +// request: 原始的创建虚拟机请求数据 +// 返回值: 格式化后的结果,包含创建信息、结果详情和摘要 +func (c *CloudpodsServerCreateTool) formatCreateResult(response *models.CreateServerResponse, request *models.CreateServerRequest) map[string]interface{} { + // 初始化格式化结果结构 + formatted := map[string]interface{}{ + // 创建请求的基本信息 + "create_info": map[string]interface{}{ + "name": request.Name, + "vcpu_count": request.VcpuCount, + "vmem_size": request.VmemSize, + "image_id": request.ImageId, + "disk_size": request.DiskSize, + "network_id": request.NetworkId, + "serversku_id": request.ServerskuId, + "count": request.Count, + "auto_start": request.AutoStart, + "billing_type": request.BillingType, + "duration": request.Duration, + "description": request.Description, + "hostname": request.Hostname, + "hypervisor": request.Hypervisor, + "secgroup_id": request.SecgroupId, + "keypair_id": request.KeypairId, + "project_id": request.ProjectId, + "zone_id": request.ZoneId, + "region_id": request.RegionId, + "disable_delete": request.DisableDelete, + "boot_order": request.BootOrder, + }, + // 创建响应的结果信息 + "result": map[string]interface{}{ + "status": response.Status, + "message": response.Message, + "servers": make([]map[string]interface{}, 0, len(response.Data.Servers)), + }, + } + + // 遍历创建的虚拟机列表,构造每个虚拟机的详细信息 + for _, server := range response.Data.Servers { + serverInfo := map[string]interface{}{ + "id": server.ID, + "name": server.Name, + "status": server.Status, + "task_id": server.TaskID, + } + formatted["result"].(map[string]interface{})["servers"] = append( + formatted["result"].(map[string]interface{})["servers"].([]map[string]interface{}), + serverInfo, + ) + } + + // 构造摘要信息 + formatted["summary"] = map[string]interface{}{ + "requested_count": request.Count, // 请求创建的虚拟机数量 + "created_count": len(response.Data.Servers), // 实际创建的虚拟机数量 + "success": response.Status == 200, // 创建是否成功 + } + + return formatted +} diff --git a/pkg/mcp-server/tools/cloudpods_server_monitor_tool.go b/pkg/mcp-server/tools/cloudpods_server_monitor_tool.go new file mode 100644 index 0000000000..d2561cdea9 --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_server_monitor_tool.go @@ -0,0 +1,366 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsServerMonitorTool 用于获取Cloudpods虚拟机监控信息 +type CloudpodsServerMonitorTool struct { + // adapter 用于与 Cloudpods API 进行交互 + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServerMonitorTool 创建一个新的CloudpodsServerMonitorTool实例 +// adapter: 用于与Cloudpods API交互的适配器 +// 返回值: CloudpodsServerMonitorTool实例指针 +func NewCloudpodsServerMonitorTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServerMonitorTool { + return &CloudpodsServerMonitorTool{ + adapter: adapter, + } +} + +// GetTool 定义并返回获取虚拟机监控信息工具的元数据 +// 该工具用于获取Cloudpods虚拟机的监控信息,包括CPU、内存、磁盘、网络等指标 +// server_id: 虚拟机ID (必填) +// start_time: 开始时间戳(秒),默认为1小时前 +// end_time: 结束时间戳(秒),默认为当前时间 +// metrics: 监控指标,多个用逗号分隔,例如:cpu_usage,mem_usage,disk_usage,net_bps_rx,net_bps_tx +// ak: 用户登录cloudpods后获取的access key +// sk: 用户登录cloudpods后获取的secret key +func (c *CloudpodsServerMonitorTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_get_server_monitor", + mcp.WithDescription("获取Cloudpods虚拟机监控信息,包括CPU、内存、磁盘、网络等指标"), + mcp.WithString("server_id", mcp.Required(), mcp.Description("虚拟机ID")), + mcp.WithString("start_time", mcp.Description("开始时间戳(秒),默认为1小时前")), + mcp.WithString("end_time", mcp.Description("结束时间戳(秒),默认为当前时间")), + mcp.WithString("metrics", mcp.Description("监控指标,多个用逗号分隔,例如:cpu_usage,mem_usage,disk_usage,net_bps_rx,net_bps_tx")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理获取虚拟机监控信息的请求 +// ctx: 控制生命周期的上下文 +// req: 包含获取监控信息所需参数的请求对象 +// 返回值: 包含监控信息的响应对象和可能的错误 +func (c *CloudpodsServerMonitorTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 获取必填参数:虚拟机ID + serverID, err := req.RequireString("server_id") + if err != nil { + return nil, err + } + + // 设置默认时间范围:结束时间为当前时间,开始时间为1小时前 + now := time.Now().Unix() + startTime := now - 3600 + + // 解析开始时间参数,如果指定则使用指定值 + if startTimeStr := req.GetString("start_time", ""); startTimeStr != "" { + if parsedStartTime, err := strconv.ParseInt(startTimeStr, 10, 64); err == nil { + startTime = parsedStartTime + } + } + + // 解析结束时间参数,如果指定则使用指定值 + endTime := now + if endTimeStr := req.GetString("end_time", ""); endTimeStr != "" { + if parsedEndTime, err := strconv.ParseInt(endTimeStr, 10, 64); err == nil { + endTime = parsedEndTime + } + } + + // 获取可选参数:监控指标 + var metrics []string + if metricsStr := req.GetString("metrics", ""); metricsStr != "" { + metrics = strings.Split(metricsStr, ",") + for i, metric := range metrics { + metrics[i] = strings.TrimSpace(metric) + } + } else { + metrics = []string{"cpu_usage", "mem_usage", "disk_usage", "net_bps_rx", "net_bps_tx"} + } + + // 获取ak和sk参数,用于认证 + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器获取虚拟机监控信息 + monitorResponse, err := c.adapter.GetServerMonitor(ctx, serverID, startTime, endTime, metrics, ak, sk) + if err != nil { + log.Errorf("Fail to get server monitor: %s", err) + return nil, fmt.Errorf("fail to get server monitor: %w", err) + } + + // 格式化监控结果 + formattedResult := c.formatMonitorResult(monitorResponse, serverID, startTime, endTime, metrics) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + + // 返回格式化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回工具的名称标识符 +// 返回值: 工具名称字符串,用于唯一标识该工具 +func (c *CloudpodsServerMonitorTool) GetName() string { + return "cloudpods_get_server_monitor" +} + +// formatMonitorResult 格式化虚拟机监控信息的响应结果 +// response: 原始监控响应数据 +// serverID: 虚拟机ID +// startTime: 监控开始时间 +// endTime: 监控结束时间 +// requestedMetrics: 请求的监控指标 +// 返回值: 包含监控信息的格式化结果 +func (c *CloudpodsServerMonitorTool) formatMonitorResult(response *models.MonitorResponse, serverID string, startTime, endTime int64, requestedMetrics []string) map[string]interface{} { + // 初始化格式化结果结构 + formatted := map[string]interface{}{ + // 添加请求的基本信息 + "query_info": map[string]interface{}{ + "server_id": serverID, + "start_time": startTime, + "end_time": endTime, + "start_time_human": time.Unix(startTime, 0).Format("2006-01-02 15:04:05"), + "end_time_human": time.Unix(endTime, 0).Format("2006-01-02 15:04:05"), + "requested_metrics": requestedMetrics, + "duration_seconds": endTime - startTime, + }, + "status": response.Status, + "metrics": make([]map[string]interface{}, 0, len(response.Data.Metrics)), + } + + for _, metric := range response.Data.Metrics { + metricInfo := map[string]interface{}{ + "metric": metric.Metric, + "unit": metric.Unit, + "data_points": len(metric.Values), + "values": make([]map[string]interface{}, 0, len(metric.Values)), + } + + var totalValue float64 + var minValue, maxValue float64 + var latestValue float64 + var latestTime int64 + + for i, value := range metric.Values { + valueInfo := map[string]interface{}{ + "timestamp": value.Timestamp, + "time_human": time.Unix(value.Timestamp, 0).Format("2006-01-02 15:04:05"), + "value": value.Value, + } + metricInfo["values"] = append(metricInfo["values"].([]map[string]interface{}), valueInfo) + + totalValue += value.Value + if i == 0 { + minValue = value.Value + maxValue = value.Value + } else { + if value.Value < minValue { + minValue = value.Value + } + if value.Value > maxValue { + maxValue = value.Value + } + } + + if value.Timestamp > latestTime { + latestTime = value.Timestamp + latestValue = value.Value + } + } + + if len(metric.Values) > 0 { + metricInfo["statistics"] = map[string]interface{}{ + "min": minValue, + "max": maxValue, + "average": totalValue / float64(len(metric.Values)), + "latest": latestValue, + } + } + + formatted["metrics"] = append(formatted["metrics"].([]map[string]interface{}), metricInfo) + } + + formatted["summary"] = map[string]interface{}{ + "total_metrics": len(response.Data.Metrics), + "query_successful": response.Status == 200, + "time_range_hours": float64(endTime-startTime) / 3600, + } + + return formatted +} + +// CloudpodsServerStatsTool 用于获取Cloudpods虚拟机实时统计信息 +type CloudpodsServerStatsTool struct { + // adapter 用于与 Cloudpods API 进行交互 + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServerStatsTool 创建一个新的CloudpodsServerStatsTool实例 +// adapter: 用于与Cloudpods API交互的适配器 +// 返回值: CloudpodsServerStatsTool实例指针 +func NewCloudpodsServerStatsTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServerStatsTool { + return &CloudpodsServerStatsTool{ + adapter: adapter, + } +} + +// GetTool 定义并返回获取虚拟机统计信息工具的元数据 +// 该工具用于获取Cloudpods虚拟机的实时统计信息,包括CPU使用率、内存使用率、磁盘使用率和网络流量 +// server_id: 虚拟机ID (必填) +// ak: 用户登录cloudpods后获取的access key +// sk: 用户登录cloudpods后获取的secret key +func (c *CloudpodsServerStatsTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_get_server_stats", + mcp.WithDescription("获取Cloudpods虚拟机实时统计信息,包括CPU使用率、内存使用率、磁盘使用率和网络流量"), + mcp.WithString("server_id", mcp.Required(), mcp.Description("虚拟机ID")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理获取虚拟机统计信息的请求 +// ctx: 控制生命周期的上下文 +// req: 包含获取统计信息所需参数的请求对象 +// 返回值: 包含统计信息的响应对象和可能的错误 +func (c *CloudpodsServerStatsTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 获取必填参数:虚拟机ID + serverID, err := req.RequireString("server_id") + if err != nil { + return nil, err + } + + // 获取可选参数:访问凭证 + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器获取虚拟机统计信息 + statsResponse, err := c.adapter.GetServerStats(ctx, serverID, ak, sk) + if err != nil { + log.Errorf("Fail to get server stats: %s", err) + return nil, fmt.Errorf("fail to get server stats: %w", err) + } + + // 格式化统计结果 + formattedResult := c.formatStatsResult(statsResponse, serverID) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + + // 返回格式化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回工具的名称标识符 +// 返回值: 工具名称字符串,用于唯一标识该工具 +func (c *CloudpodsServerStatsTool) GetName() string { + return "cloudpods_get_server_stats" +} + +// formatStatsResult 格式化虚拟机统计信息的响应结果 +// response: 原始统计响应数据 +// serverID: 虚拟机ID +// 返回值: 包含统计信息的格式化结果 +func (c *CloudpodsServerStatsTool) formatStatsResult(response *models.ServerStatsResponse, serverID string) map[string]interface{} { + // 初始化格式化结果结构 + formatted := map[string]interface{}{ + "server_id": serverID, + "status": response.Status, + // 添加统计信息 + "stats": map[string]interface{}{ + "cpu_usage": fmt.Sprintf("%.1f%%", response.Data.CPUUsage), + "memory_usage": fmt.Sprintf("%.1f%%", response.Data.MemUsage), + "disk_usage": fmt.Sprintf("%.1f%%", response.Data.DiskUsage), + "network": map[string]interface{}{ + "receive_bps": response.Data.NetBpsRx, + "transmit_bps": response.Data.NetBpsTx, + "receive_mbps": fmt.Sprintf("%.2f Mbps", float64(response.Data.NetBpsRx)/(1024*1024)), + "transmit_mbps": fmt.Sprintf("%.2f Mbps", float64(response.Data.NetBpsTx)/(1024*1024)), + }, + "updated_at": response.Data.UpdatedAt, + }, + // 添加原始数据 + "raw_data": map[string]interface{}{ + "cpu_usage": response.Data.CPUUsage, + "mem_usage": response.Data.MemUsage, + "disk_usage": response.Data.DiskUsage, + "net_bps_rx": response.Data.NetBpsRx, + "net_bps_tx": response.Data.NetBpsTx, + }, + } + + // 评估虚拟机健康状态 + var healthStatus string + var healthScore int + + if response.Data.CPUUsage > 90 || response.Data.MemUsage > 90 || response.Data.DiskUsage > 90 { + healthStatus = "警告" + healthScore = 1 + } else if response.Data.CPUUsage > 70 || response.Data.MemUsage > 70 || response.Data.DiskUsage > 80 { + healthStatus = "注意" + healthScore = 2 + } else { + healthStatus = "正常" + healthScore = 3 + } + + // 添加健康状态信息 + formatted["health"] = map[string]interface{}{ + "status": healthStatus, + "score": healthScore, + "notes": []string{}, + } + + // 添加健康状态建议 + notes := []string{} + if response.Data.CPUUsage > 90 { + notes = append(notes, "CPU使用率过高,建议检查系统负载") + } + if response.Data.MemUsage > 90 { + notes = append(notes, "内存使用率过高,建议释放内存或增加内存") + } + if response.Data.DiskUsage > 90 { + notes = append(notes, "磁盘使用率过高,建议清理磁盘空间") + } + formatted["health"].(map[string]interface{})["notes"] = notes + + return formatted +} diff --git a/pkg/mcp-server/tools/cloudpods_server_operations_tool.go b/pkg/mcp-server/tools/cloudpods_server_operations_tool.go new file mode 100644 index 0000000000..22736713fe --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_server_operations_tool.go @@ -0,0 +1,524 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsServerStartTool 用于启动指定的Cloudpods虚拟机实例 +type CloudpodsServerStartTool struct { + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServerStartTool 创建一个新的CloudpodsServerStartTool实例 +func NewCloudpodsServerStartTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServerStartTool { + return &CloudpodsServerStartTool{ + adapter: adapter, + } +} + +// GetTool 返回启动虚拟机工具的定义,包括参数和描述 +func (c *CloudpodsServerStartTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_start_server", + mcp.WithDescription("启动指定的Cloudpods虚拟机实例"), + mcp.WithString("server_id", mcp.Required(), mcp.Description("虚拟机ID")), + mcp.WithString("auto_prepaid", mcp.Description("按量机器自动转换为包年包月,默认为false")), + mcp.WithString("qemu_version", mcp.Description("指定启动虚拟机的Qemu版本,可选值:2.12.1, 4.2.0,仅适用于KVM虚拟机")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理启动虚拟机的请求,调用适配器执行启动操作并返回结果 +func (c *CloudpodsServerStartTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 从请求中获取必需的 server_id 参数 + serverID, err := req.RequireString("server_id") + if err != nil { + return nil, err + } + + // 解析 auto_prepaid 参数,决定是否自动转换为包年包月 + autoPrepaid := false + if autoPrepaidStr := req.GetString("auto_prepaid", "false"); autoPrepaidStr == "true" { + autoPrepaid = true + } + + // 获取 qemu_version 参数,用于指定启动虚拟机的 Qemu 版本 + qemuVersion := req.GetString("qemu_version", "") + + // 构造启动虚拟机的请求参数 + startReq := models.ServerStartRequest{ + AutoPrepaid: autoPrepaid, + QemuVersion: qemuVersion, + } + + // 获取认证所需的 access key 和 secret key + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器的 StartServer 方法执行启动操作 + response, err := c.adapter.StartServer(ctx, serverID, startReq, ak, sk) + if err != nil { + log.Errorf("Fail to start server: %s", err) + return nil, fmt.Errorf("fail to start server: %w", err) + } + + // 构造返回结果,包含任务ID、成功状态和状态信息 + result := map[string]interface{}{ + "server_id": serverID, + "operation": "start", + "task_id": response.TaskId, + "success": response.Success, + "status": response.Status, + } + + // 如果有错误信息,则添加到结果中 + if response.Error != "" { + result["error"] = response.Error + } + + // 将结果序列化为 JSON 格式 + resultJSON, err := json.MarshalIndent(result, "", " ") + if err != nil { + return nil, fmt.Errorf("序列化结果失败: %w", err) + } + + // 返回序列化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回启动虚拟机工具的名称 +func (c *CloudpodsServerStartTool) GetName() string { + return "cloudpods_start_server" +} + +// CloudpodsServerStopTool 用于停止指定的Cloudpods虚拟机实例 +type CloudpodsServerStopTool struct { + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServerStopTool 创建一个新的CloudpodsServerStopTool实例 +func NewCloudpodsServerStopTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServerStopTool { + return &CloudpodsServerStopTool{ + adapter: adapter, + } +} + +// GetTool 返回停止虚拟机工具的定义,包括参数和描述 +func (c *CloudpodsServerStopTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_stop_server", + mcp.WithDescription("停止指定的Cloudpods虚拟机实例"), + mcp.WithString("server_id", mcp.Required(), mcp.Description("虚拟机ID")), + mcp.WithString("is_force", mcp.Description("是否强制停止,默认为false")), + mcp.WithString("stop_charging", mcp.Description("是否关机停止计费,默认为false")), + mcp.WithString("timeout_secs", mcp.Description("关机等待时间,如果是强制关机,则等待时间为0,如果不设置,默认为30秒")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理停止虚拟机的请求,调用适配器执行停止操作并返回结果 +func (c *CloudpodsServerStopTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 从请求中获取必需的 server_id 参数 + serverID, err := req.RequireString("server_id") + if err != nil { + return nil, err + } + + // 解析 is_force 参数,决定是否强制停止虚拟机 + isForce := false + if isForceStr := req.GetString("is_force", "false"); isForceStr == "true" { + isForce = true + } + + // 解析 stop_charging 参数,决定是否停止计费 + stopCharging := false + if stopChargingStr := req.GetString("stop_charging", "false"); stopChargingStr == "true" { + stopCharging = true + } + + // 解析 timeout_secs 参数,设置停止操作的超时时间 + timeoutSecs := int64(0) + if timeoutSecsStr := req.GetString("timeout_secs", ""); timeoutSecsStr != "" { + if parsed, err := strconv.ParseInt(timeoutSecsStr, 10, 64); err == nil && parsed > 0 { + timeoutSecs = parsed + } + } + + // 构造停止虚拟机的请求参数 + stopReq := models.ServerStopRequest{ + IsForce: isForce, + StopCharging: stopCharging, + TimeoutSecs: timeoutSecs, + } + + // 获取认证所需的 access key 和 secret key + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器的 StopServer 方法执行停止操作 + response, err := c.adapter.StopServer(ctx, serverID, stopReq, ak, sk) + if err != nil { + log.Errorf("Fail to stop server: %s", err) + return nil, fmt.Errorf("fail to stop server: %w", err) + } + + // 构造返回结果,包含任务ID、成功状态和状态信息 + result := map[string]interface{}{ + "server_id": serverID, + "operation": "stop", + "task_id": response.TaskId, + "success": response.Success, + "status": response.Status, + } + + // 如果有错误信息,则添加到结果中 + if response.Error != "" { + result["error"] = response.Error + } + + // 将结果序列化为 JSON 格式 + resultJSON, err := json.MarshalIndent(result, "", " ") + if err != nil { + return nil, fmt.Errorf("序列化结果失败: %w", err) + } + + // 返回序列化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回停止虚拟机工具的名称 +func (c *CloudpodsServerStopTool) GetName() string { + return "cloudpods_stop_server" +} + +// CloudpodsServerRestartTool 用于重启指定的Cloudpods虚拟机实例 +type CloudpodsServerRestartTool struct { + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServerRestartTool 创建一个新的CloudpodsServerRestartTool实例 +func NewCloudpodsServerRestartTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServerRestartTool { + return &CloudpodsServerRestartTool{ + adapter: adapter, + } +} + +// GetTool 返回重启虚拟机工具的定义,包括参数和描述 +func (c *CloudpodsServerRestartTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_restart_server", + mcp.WithDescription("重启指定的Cloudpods虚拟机实例"), + mcp.WithString("server_id", mcp.Required(), mcp.Description("虚拟机ID")), + mcp.WithString("is_force", mcp.Description("是否强制重启,默认为false")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理重启虚拟机的请求,调用适配器执行重启操作并返回结果 +func (c *CloudpodsServerRestartTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 从请求中获取必需的 server_id 参数 + serverID, err := req.RequireString("server_id") + if err != nil { + return nil, err + } + + // 解析 is_force 参数,决定是否强制重启虚拟机 + isForce := false + if isForceStr := req.GetString("is_force", "false"); isForceStr == "true" { + isForce = true + } + + // 构造重启虚拟机的请求参数 + restartReq := models.ServerRestartRequest{ + IsForce: isForce, + } + + // 获取认证所需的 access key 和 secret key + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器的 RestartServer 方法执行重启操作 + response, err := c.adapter.RestartServer(ctx, serverID, restartReq, ak, sk) + if err != nil { + log.Errorf("Fail to query restart server: %s", err) + return nil, fmt.Errorf("fail to restart server: %w", err) + } + + // 构造返回结果,包含任务ID、成功状态和状态信息 + result := map[string]interface{}{ + "server_id": serverID, + "operation": "restart", + "task_id": response.TaskId, + "success": response.Success, + "status": response.Status, + } + + // 如果有错误信息,则添加到结果中 + if response.Error != "" { + result["error"] = response.Error + } + + // 将结果序列化为 JSON 格式 + resultJSON, err := json.MarshalIndent(result, "", " ") + if err != nil { + return nil, fmt.Errorf("序列化结果失败: %w", err) + } + + // 返回序列化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回重启虚拟机工具的名称 +func (c *CloudpodsServerRestartTool) GetName() string { + return "cloudpods_restart_server" +} + +// CloudpodsServerResetPasswordTool 用于重置指定Cloudpods虚拟机的登录密码 +type CloudpodsServerResetPasswordTool struct { + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServerResetPasswordTool 创建一个新的CloudpodsServerResetPasswordTool实例 +func NewCloudpodsServerResetPasswordTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServerResetPasswordTool { + return &CloudpodsServerResetPasswordTool{ + adapter: adapter, + } +} + +// GetTool 返回重置虚拟机密码工具的定义,包括参数和描述 +func (c *CloudpodsServerResetPasswordTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_reset_server_password", + mcp.WithDescription("重置指定Cloudpods虚拟机的登录密码"), + mcp.WithString("server_id", mcp.Required(), mcp.Description("虚拟机ID")), + mcp.WithString("password", mcp.Required(), mcp.Description("新密码,长度8-30个字符")), + mcp.WithString("reset_password", mcp.Description("是否重置密码,默认为true")), + mcp.WithString("auto_start", mcp.Description("重置后是否自动启动,默认为true")), + mcp.WithString("username", mcp.Description("用户名,可选,默认为空")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理重置虚拟机密码的请求,调用适配器执行密码重置操作并返回结果 +func (c *CloudpodsServerResetPasswordTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 从请求中获取必需的 server_id 参数 + serverID, err := req.RequireString("server_id") + if err != nil { + return nil, err + } + + // 从请求中获取必需的 password 参数,并验证其长度 + password, err := req.RequireString("password") + if err != nil { + return nil, err + } + + if len(password) < 8 || len(password) > 30 { + return nil, fmt.Errorf("密码长度必须在8-30个字符之间") + } + + // 解析 reset_password 参数,决定是否重置密码 + resetPassword := true + if resetPasswordStr := req.GetString("reset_password", "true"); resetPasswordStr == "false" { + resetPassword = false + } + + // 解析 auto_start 参数,决定重置密码后是否自动启动虚拟机 + autoStart := true + if autoStartStr := req.GetString("auto_start", "true"); autoStartStr == "false" { + autoStart = false + } + + // 获取 username 参数,可选 + username := req.GetString("username", "") + + // 构造重置虚拟机密码的请求参数 + resetPasswordReq := models.ServerResetPasswordRequest{ + Password: password, + ResetPassword: resetPassword, + AutoStart: autoStart, + Username: username, + } + + // 获取认证所需的 access key 和 secret key + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器的 ResetServerPassword 方法执行密码重置操作 + response, err := c.adapter.ResetServerPassword(ctx, serverID, resetPasswordReq, ak, sk) + if err != nil { + log.Errorf("Fail to reset server password: %s", err) + return nil, fmt.Errorf("fail to reset server password: %w", err) + } + + // 构造返回结果,包含任务ID、成功状态和状态信息 + result := map[string]interface{}{ + "server_id": serverID, + "operation": "reset-password", + "task_id": response.TaskId, + "success": response.Success, + "status": response.Status, + } + + // 如果有错误信息,则添加到结果中 + if response.Error != "" { + result["error"] = response.Error + } + + // 将结果序列化为 JSON 格式 + resultJSON, err := json.MarshalIndent(result, "", " ") + if err != nil { + return nil, fmt.Errorf("序列化结果失败: %w", err) + } + + // 返回序列化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回重置虚拟机密码工具的名称 +func (c *CloudpodsServerResetPasswordTool) GetName() string { + return "cloudpods_reset_server_password" +} + +// CloudpodsServerDeleteTool 用于删除指定的Cloudpods虚拟机实例 +type CloudpodsServerDeleteTool struct { + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServerDeleteTool 创建一个新的CloudpodsServerDeleteTool实例 +func NewCloudpodsServerDeleteTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServerDeleteTool { + return &CloudpodsServerDeleteTool{ + adapter: adapter, + } +} + +// GetTool 返回删除虚拟机工具的定义,包括参数和描述 +func (c *CloudpodsServerDeleteTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_delete_server", + mcp.WithDescription("删除指定的Cloudpods虚拟机实例"), + mcp.WithString("server_id", mcp.Required(), mcp.Description("虚拟机ID")), + mcp.WithString("override_pending_delete", mcp.Description("是否强制删除(包括在回收站中的实例),默认为false")), + mcp.WithString("purge", mcp.Description("是否仅删除本地资源,默认为false")), + mcp.WithString("delete_snapshots", mcp.Description("是否删除快照,默认为false")), + mcp.WithString("delete_eip", mcp.Description("是否删除关联的EIP,默认为false")), + mcp.WithString("delete_disks", mcp.Description("是否删除关联的数据盘,默认为false")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理删除虚拟机的请求,调用适配器执行删除操作并返回结果 +func (c *CloudpodsServerDeleteTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 从请求中获取必需的 server_id 参数 + serverID, err := req.RequireString("server_id") + if err != nil { + return nil, err + } + + // 解析 override_pending_delete 参数,决定是否强制删除(包括在回收站中的实例) + overridePendingDelete := false + if overrideStr := req.GetString("override_pending_delete", "false"); overrideStr == "true" { + overridePendingDelete = true + } + + // 解析 purge 参数,决定是否仅删除本地资源 + purge := false + if purgeStr := req.GetString("purge", "false"); purgeStr == "true" { + purge = true + } + + // 解析 delete_snapshots 参数,决定是否删除快照 + deleteSnapshots := false + if deleteSnapshotsStr := req.GetString("delete_snapshots", "false"); deleteSnapshotsStr == "true" { + deleteSnapshots = true + } + + // 解析 delete_eip 参数,决定是否删除关联的EIP + deleteEip := false + if deleteEipStr := req.GetString("delete_eip", "false"); deleteEipStr == "true" { + deleteEip = true + } + + // 解析 delete_disks 参数,决定是否删除关联的数据盘 + deleteDisks := false + if deleteDisksStr := req.GetString("delete_disks", "false"); deleteDisksStr == "true" { + deleteDisks = true + } + + // 构造删除虚拟机的请求参数 + deleteReq := models.ServerDeleteRequest{ + OverridePendingDelete: overridePendingDelete, + Purge: purge, + DeleteSnapshots: deleteSnapshots, + DeleteEip: deleteEip, + DeleteDisks: deleteDisks, + } + + // 获取认证所需的 access key 和 secret key + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器的 DeleteServer 方法执行删除操作 + response, err := c.adapter.DeleteServer(ctx, serverID, deleteReq, ak, sk) + if err != nil { + log.Errorf("Fail to delete server: %s", err) + return nil, fmt.Errorf("fail to delete server: %w", err) + } + + // 构造返回结果,包含任务ID、成功状态和状态信息 + result := map[string]interface{}{ + "server_id": serverID, + "operation": "delete", + "task_id": response.TaskId, + "success": response.Success, + "status": response.Status, + } + + // 如果有错误信息,则添加到结果中 + if response.Error != "" { + result["error"] = response.Error + } + + // 将结果序列化为 JSON 格式 + resultJSON, err := json.MarshalIndent(result, "", " ") + if err != nil { + return nil, fmt.Errorf("序列化结果失败: %w", err) + } + + // 返回序列化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回删除虚拟机工具的名称 +func (c *CloudpodsServerDeleteTool) GetName() string { + return "cloudpods_delete_server" +} diff --git a/pkg/mcp-server/tools/cloudpods_servers_tool.go b/pkg/mcp-server/tools/cloudpods_servers_tool.go new file mode 100644 index 0000000000..abd8c1fa03 --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_servers_tool.go @@ -0,0 +1,176 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsServersTool 是用于查询 Cloudpods 虚拟机实例列表的工具 +type CloudpodsServersTool struct { + // adapter 用于与 Cloudpods API 进行交互 + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServersTool 创建一个新的 Cloudpods 虚拟机查询工具 +// adapter: 用于与Cloudpods API交互的适配器 +// 返回值: CloudpodsServersTool实例指针 +func NewCloudpodsServersTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServersTool { + return &CloudpodsServersTool{ + adapter: adapter, + } +} + +// GetTool 定义并返回查询虚拟机实例列表工具的元数据 +// 该工具用于查询Cloudpods虚拟机实例列表,获取虚拟机信息 +// limit: 返回结果数量限制,默认为50 +// offset: 结果偏移量,默认为0 +// search: 按名称或ID模糊搜索 +// status: 虚拟机状态,例如:running、stopped、creating等 +// ak: 用户登录cloudpods后获取的access key +// sk: 用户登录cloudpods后获取的secret key +func (c *CloudpodsServersTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_list_servers", + mcp.WithDescription("查询Cloudpods虚拟机实例列表,获取虚拟机信息"), + mcp.WithString("limit", mcp.Description("返回结果数量限制,默认为50")), + mcp.WithString("offset", mcp.Description("结果偏移量,默认为0")), + mcp.WithString("search", mcp.Description("按名称或ID模糊搜索")), + mcp.WithString("status", mcp.Description("虚拟机状态,例如:running、stopped、creating等")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理查询 Cloudpods 虚拟机实例列表的请求 +// ctx: 控制生命周期的上下文 +// req: 包含查询参数的请求对象 +// 返回值: 包含虚拟机列表的响应对象和可能的错误 +func (c *CloudpodsServersTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 获取可选参数:返回结果数量限制,如果指定则转换为整数 + limit := 50 + if limitStr := req.GetString("limit", ""); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + // 获取可选参数:结果偏移量,如果指定则转换为整数 + offset := 0 + if offsetStr := req.GetString("offset", ""); offsetStr != "" { + if parsedOffset, err := strconv.Atoi(offsetStr); err == nil && parsedOffset >= 0 { + offset = parsedOffset + } + } + + // 获取可选参数:搜索关键词和虚拟机状态 + search := req.GetString("search", "") + status := req.GetString("status", "") + + // 获取可选参数:访问凭证 + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器查询虚拟机列表 + serversResponse, err := c.adapter.ListServers(ctx, limit, offset, search, status, ak, sk) + if err != nil { + log.Errorf("Fail to query server: %s", err) + return nil, fmt.Errorf("fail to query server: %w", err) + } + + // 格式化查询结果 + formattedResult := c.formatServersResult(serversResponse, limit, offset, search, status) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + + // 返回格式化后的结果 + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回工具的名称标识符 +// 返回值: 工具名称字符串,用于唯一标识该工具 +func (c *CloudpodsServersTool) GetName() string { + return "cloudpods_list_servers" +} + +// formatServersResult 格式化虚拟机实例列表查询结果 +// response: 原始虚拟机列表响应数据 +// limit: 查询限制数量 +// offset: 查询偏移量 +// search: 搜索关键词 +// status: 虚拟机状态 +// 返回值: 包含虚拟机列表的格式化结果 +func (c *CloudpodsServersTool) formatServersResult(response *models.ServerListResponse, limit int, offset int, search string, status string) map[string]interface{} { + // 初始化格式化结果结构 + formatted := map[string]interface{}{ + // 添加查询信息 + "query_info": map[string]interface{}{ + "limit": limit, + "offset": offset, + "search": search, + "status": status, + "total": response.Total, + "count": len(response.Servers), + }, + // 初始化虚拟机列表 + "servers": make([]map[string]interface{}, 0, len(response.Servers)), + } + + // 遍历虚拟机列表,构造每个虚拟机的详细信息 + for _, server := range response.Servers { + // 将内存大小从MB转换为GB + memoryGB := float64(server.VmemSize) / 1024 + + // 构造虚拟机信息 + serverInfo := map[string]interface{}{ + "id": server.Id, + "name": server.Name, + "status": server.Status, + "vcpu_count": server.VcpuCount, + "vmem_size": server.VmemSize, + "memory_gb": fmt.Sprintf("%.1f GB", memoryGB), + "os_name": server.OsName, + "ips": server.Ips, + "host": server.Host, + "zone": server.Zone, + "region": server.Cloudregion, + "created_at": server.CreatedAt, + } + formatted["servers"] = append(formatted["servers"].([]map[string]interface{}), serverInfo) + } + + // 构造摘要信息 + formatted["summary"] = map[string]interface{}{ + "total_servers": response.Total, + "returned_count": len(response.Servers), + } + + return formatted +} diff --git a/pkg/mcp-server/tools/cloudpods_serverskus_tool.go b/pkg/mcp-server/tools/cloudpods_serverskus_tool.go new file mode 100644 index 0000000000..e2542be18b --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_serverskus_tool.go @@ -0,0 +1,310 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsServerSkusTool 用于查询Cloudpods主机套餐规格列表的工具 +type CloudpodsServerSkusTool struct { + // adapter 用于与Cloudpods API进行交互 + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsServerSkusTool 创建一个新的CloudpodsServerSkusTool实例 +// +// 参数: +// - adapter: 用于与Cloudpods API交互的适配器 +// +// 返回值: +// - *CloudpodsServerSkusTool: CloudpodsServerSkusTool实例指针 +func NewCloudpodsServerSkusTool(adapter *adapters.CloudpodsAdapter) *CloudpodsServerSkusTool { + return &CloudpodsServerSkusTool{ + adapter: adapter, + } +} + +// GetTool 定义并返回查询主机套餐规格列表工具的元数据 +// +// 工具用途: +// +// 查询Cloudpods主机套餐规格列表,获取虚拟机规格信息 +// +// 参数说明: +// - limit: 返回结果数量限制,默认为20 +// - offset: 返回结果偏移量,默认为0 +// - search: 搜索关键词,可以按规格名称搜索 +// - cloudregion_ids: 云区域ID,多个用逗号分隔 +// - zone_ids: 可用区ID,多个用逗号分隔 +// - cpu_core_count: CPU核心数,多个用逗号分隔,如:1,2,4,8 +// - memory_size_mb: 内存大小MB,多个用逗号分隔,如:1024,2048,4096 +// - providers: 云平台提供商,多个用逗号分隔,如:OneCloud,Aliyun,Huawei +// - cpu_arch: CPU架构,多个用逗号分隔,如:x86,arm +// - ak: 用户登录cloudpods后获取的access key +// - sk: 用户登录cloudpods后获取的secret key +func (c *CloudpodsServerSkusTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_list_serverskus", + mcp.WithDescription("查询Cloudpods主机套餐规格列表,获取虚拟机规格信息"), + mcp.WithString("limit", mcp.Description("返回结果数量限制,默认为20")), + mcp.WithString("offset", mcp.Description("返回结果偏移量,默认为0")), + mcp.WithString("search", mcp.Description("搜索关键词,可以按规格名称搜索")), + mcp.WithString("cloudregion_ids", mcp.Description("云区域ID,多个用逗号分隔")), + mcp.WithString("zone_ids", mcp.Description("可用区ID,多个用逗号分隔")), + mcp.WithString("cpu_core_count", mcp.Description("CPU核心数,多个用逗号分隔,如:1,2,4,8")), + mcp.WithString("memory_size_mb", mcp.Description("内存大小MB,多个用逗号分隔,如:1024,2048,4096")), + mcp.WithString("providers", mcp.Description("云平台提供商,多个用逗号分隔,如:OneCloud,Aliyun,Huawei")), + mcp.WithString("cpu_arch", mcp.Description("CPU架构,多个用逗号分隔,如:x86,arm")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理查询主机套餐规格列表的请求 +// +// 参数: +// - ctx: 控制生命周期的上下文 +// - req: 包含查询参数的请求对象 +// +// 返回值: +// - *mcp.CallToolResult: 包含主机套餐规格列表的响应对象 +// - error: 可能的错误信息 +func (c *CloudpodsServerSkusTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 获取可选参数:返回结果数量限制,如果指定则转换为整数 + limit := 20 + if limitStr := req.GetString("limit", ""); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + // 获取可选参数:结果偏移量,如果指定则转换为整数 + offset := 0 + if offsetStr := req.GetString("offset", ""); offsetStr != "" { + if parsedOffset, err := strconv.Atoi(offsetStr); err == nil && parsedOffset >= 0 { + offset = parsedOffset + } + } + + // 获取可选参数:搜索关键词 + search := req.GetString("search", "") + + // 获取可选参数:云区域ID列表 + var cloudregionIds []string + if cloudregionIdsStr := req.GetString("cloudregion_ids", ""); cloudregionIdsStr != "" { + cloudregionIds = strings.Split(cloudregionIdsStr, ",") + for i, id := range cloudregionIds { + cloudregionIds[i] = strings.TrimSpace(id) + } + } + + // 获取可选参数:可用区ID列表 + var zoneIds []string + if zoneIdsStr := req.GetString("zone_ids", ""); zoneIdsStr != "" { + zoneIds = strings.Split(zoneIdsStr, ",") + for i, id := range zoneIds { + zoneIds[i] = strings.TrimSpace(id) + } + } + + // 获取可选参数:CPU核心数列表 + var cpuCoreCount []string + if cpuCoreCountStr := req.GetString("cpu_core_count", ""); cpuCoreCountStr != "" { + cpuCoreCount = strings.Split(cpuCoreCountStr, ",") + for i, count := range cpuCoreCount { + cpuCoreCount[i] = strings.TrimSpace(count) + } + } + + // 获取可选参数:内存大小列表(MB) + var memorySizeMB []string + if memorySizeMBStr := req.GetString("memory_size_mb", ""); memorySizeMBStr != "" { + memorySizeMB = strings.Split(memorySizeMBStr, ",") + for i, size := range memorySizeMB { + memorySizeMB[i] = strings.TrimSpace(size) + } + } + + // 获取可选参数:云平台提供商列表 + var providers []string + if providersStr := req.GetString("providers", ""); providersStr != "" { + providers = strings.Split(providersStr, ",") + for i, provider := range providers { + providers[i] = strings.TrimSpace(provider) + } + } + + // 获取可选参数:CPU架构列表 + var cpuArch []string + if cpuArchStr := req.GetString("cpu_arch", ""); cpuArchStr != "" { + cpuArch = strings.Split(cpuArchStr, ",") + for i, arch := range cpuArch { + cpuArch[i] = strings.TrimSpace(arch) + } + } + + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器查询主机套餐规格列表 + skusResponse, err := c.adapter.ListServerSkus(limit, offset, search, cloudregionIds, zoneIds, cpuCoreCount, memorySizeMB, providers, cpuArch, ak, sk) + if err != nil { + log.Errorf("Fail to query server skus: %s", err) + return nil, fmt.Errorf("fail to query server skus: %w", err) + } + + // 格式化查询结果 + formattedResult := c.formatServerSkusResult(skusResponse, limit, offset, search, cloudregionIds, zoneIds, cpuCoreCount, memorySizeMB, providers, cpuArch) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回工具的名称标识符 +// +// 返回值: +// - string: 工具名称字符串,用于唯一标识该工具 +func (c *CloudpodsServerSkusTool) GetName() string { + return "cloudpods_list_serverskus" +} + +// formatServerSkusResult 格式化主机套餐规格列表的响应结果 +// +// 参数: +// - response: 原始主机套餐规格列表响应数据 +// - limit: 查询限制数量 +// - offset: 查询偏移量 +// - search: 搜索关键词 +// - cloudregionIds: 云区域ID列表 +// - zoneIds: 可用区ID列表 +// - cpuCoreCount: CPU核心数列表 +// - memorySizeMB: 内存大小列表(MB) +// - providers: 云平台提供商列表 +// - cpuArch: CPU架构列表 +// +// 返回值: +// - map[string]interface{}: 包含主机套餐规格列表的格式化结果 +func (c *CloudpodsServerSkusTool) formatServerSkusResult( + response *models.ServerSkuListResponse, + limit, offset int, + search string, + cloudregionIds, zoneIds, cpuCoreCount, memorySizeMB, providers, cpuArch []string, +) map[string]interface{} { + // 初始化格式化结果结构 + formatted := map[string]interface{}{ + "query_info": map[string]interface{}{ + "limit": limit, + "offset": offset, + "search": search, + "cloudregion_ids": cloudregionIds, + "zone_ids": zoneIds, + "cpu_core_count": cpuCoreCount, + "memory_size_mb": memorySizeMB, + "providers": providers, + "cpu_arch": cpuArch, + "total": response.Total, + "count": len(response.Serverskus), + }, + "serverskus": make([]map[string]interface{}, 0, len(response.Serverskus)), + } + + // 遍历主机套餐列表,构造每个主机套餐的详细信息 + for _, sku := range response.Serverskus { + skuInfo := map[string]interface{}{ + "id": sku.Id, + "name": sku.Name, + "description": sku.Description, + "status": sku.Status, + "enabled": sku.Enabled, + "provider": sku.Provider, + "cloud_env": sku.CloudEnv, + "cloudregion": sku.Cloudregion, + "cloudregion_id": sku.CloudregionId, + "zone": sku.Zone, + "zone_id": sku.ZoneId, + "zone_ext_id": sku.ZoneExtId, + "cpu_core_count": sku.CpuCoreCount, + "memory_size_mb": sku.MemorySizeMB, + "cpu_arch": sku.CpuArch, + "instance_type_family": sku.InstanceTypeFamily, + "instance_type_category": sku.InstanceTypeCategory, + "local_category": sku.LocalCategory, + "sys_disk_type": sku.SysDiskType, + "sys_disk_min_size_gb": sku.SysDiskMinSizeGB, + "sys_disk_max_size_gb": sku.SysDiskMaxSizeGB, + "sys_disk_resizable": sku.SysDiskResizable, + "data_disk_types": sku.DataDiskTypes, + "data_disk_max_count": sku.DataDiskMaxCount, + "attached_disk_count": sku.AttachedDiskCount, + "attached_disk_size_gb": sku.AttachedDiskSizeGB, + "attached_disk_type": sku.AttachedDiskType, + "nic_type": sku.NicType, + "nic_max_count": sku.NicMaxCount, + "gpu_attachable": sku.GpuAttachable, + "gpu_count": sku.GpuCount, + "gpu_max_count": sku.GpuMaxCount, + "gpu_spec": sku.GpuSpec, + "os_name": sku.OsName, + "postpaid_status": sku.PostpaidStatus, + "prepaid_status": sku.PrepaidStatus, + "total_guest_count": sku.TotalGuestCount, + "external_id": sku.ExternalId, + "source": sku.Source, + "is_emulated": sku.IsEmulated, + "region": sku.Region, + "region_id": sku.RegionId, + "region_ext_id": sku.RegionExtId, + "region_external_id": sku.RegionExternalId, + "md5": sku.Md5, + "metadata": sku.Metadata, + "progress": sku.Progress, + "can_delete": sku.CanDelete, + "can_update": sku.CanUpdate, + "update_version": sku.UpdateVersion, + "created_at": sku.CreatedAt, + "updated_at": sku.UpdatedAt, + "imported_at": sku.ImportedAt, + } + formatted["serverskus"] = append(formatted["serverskus"].([]map[string]interface{}), skuInfo) + } + + // 构造摘要信息 + formatted["summary"] = map[string]interface{}{ + "total_serverskus": response.Total, + "returned_count": len(response.Serverskus), + "has_more": response.Total > int64(offset+len(response.Serverskus)), + "next_offset": offset + len(response.Serverskus), + } + + return formatted +} diff --git a/pkg/mcp-server/tools/cloudpods_storages_tool.go b/pkg/mcp-server/tools/cloudpods_storages_tool.go new file mode 100644 index 0000000000..989e9c2f35 --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_storages_tool.go @@ -0,0 +1,319 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsStoragesTool 用于查询Cloudpods块存储列表的工具 +type CloudpodsStoragesTool struct { + // adapter 用于与Cloudpods API进行交互 + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsStoragesTool 创建一个新的CloudpodsStoragesTool实例 +// +// 参数: +// - adapter: 用于与Cloudpods API交互的适配器 +// +// 返回值: +// - *CloudpodsStoragesTool: CloudpodsStoragesTool实例指针 +func NewCloudpodsStoragesTool(adapter *adapters.CloudpodsAdapter) *CloudpodsStoragesTool { + return &CloudpodsStoragesTool{ + adapter: adapter, + } +} + +// GetTool 定义并返回查询块存储列表工具的元数据 +// +// 工具用途: +// +// 查询Cloudpods块存储列表,获取存储资源信息 +// +// 参数说明: +// - limit: 返回结果数量限制,默认为20 +// - offset: 返回结果偏移量,默认为0 +// - search: 搜索关键词,可以按存储名称搜索 +// - cloudregion_ids: 云区域ID,多个用逗号分隔 +// - zone_ids: 可用区ID,多个用逗号分隔 +// - providers: 云平台提供商,多个用逗号分隔,如:OneCloud,Aliyun,Huawei +// - storage_types: 存储类型,多个用逗号分隔,如:local,rbd,nfs,cephfs +// - host_id: 主机ID,过滤关联指定主机的存储 +// - ak: 用户登录cloudpods后获取的access key +// - sk: 用户登录cloudpods后获取的secret key +func (c *CloudpodsStoragesTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_list_storages", + mcp.WithDescription("查询Cloudpods块存储列表,获取存储资源信息"), + mcp.WithString("limit", mcp.Description("返回结果数量限制,默认为20")), + mcp.WithString("offset", mcp.Description("返回结果偏移量,默认为0")), + mcp.WithString("search", mcp.Description("搜索关键词,可以按存储名称搜索")), + mcp.WithString("cloudregion_ids", mcp.Description("云区域ID,多个用逗号分隔")), + mcp.WithString("zone_ids", mcp.Description("可用区ID,多个用逗号分隔")), + mcp.WithString("providers", mcp.Description("云平台提供商,多个用逗号分隔,如:OneCloud,Aliyun,Huawei")), + mcp.WithString("storage_types", mcp.Description("存储类型,多个用逗号分隔,如:local,rbd,nfs,cephfs")), + mcp.WithString("host_id", mcp.Description("主机ID,过滤关联指定主机的存储")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理查询块存储列表的请求 +// +// 参数: +// - ctx: 控制生命周期的上下文 +// - req: 包含查询参数的请求对象 +// +// 返回值: +// - *mcp.CallToolResult: 包含块存储列表的响应对象 +// - error: 可能的错误信息 +func (c *CloudpodsStoragesTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 获取可选参数:返回结果数量限制,如果指定则转换为整数 + limit := 20 + if limitStr := req.GetString("limit", ""); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + // 获取可选参数:结果偏移量,如果指定则转换为整数 + offset := 0 + if offsetStr := req.GetString("offset", ""); offsetStr != "" { + if parsedOffset, err := strconv.Atoi(offsetStr); err == nil && parsedOffset >= 0 { + offset = parsedOffset + } + } + + // 获取可选参数:搜索关键词 + search := req.GetString("search", "") + + // 获取可选参数:云区域ID列表 + var cloudregionIds []string + if cloudregionIdsStr := req.GetString("cloudregion_ids", ""); cloudregionIdsStr != "" { + cloudregionIds = strings.Split(cloudregionIdsStr, ",") + for i, id := range cloudregionIds { + cloudregionIds[i] = strings.TrimSpace(id) + } + } + + // 获取可选参数:可用区ID列表 + var zoneIds []string + if zoneIdsStr := req.GetString("zone_ids", ""); zoneIdsStr != "" { + zoneIds = strings.Split(zoneIdsStr, ",") + for i, id := range zoneIds { + zoneIds[i] = strings.TrimSpace(id) + } + } + + // 获取可选参数:云平台提供商列表 + var providers []string + if providersStr := req.GetString("providers", ""); providersStr != "" { + providers = strings.Split(providersStr, ",") + for i, provider := range providers { + providers[i] = strings.TrimSpace(provider) + } + } + + // 获取可选参数:存储类型列表 + var storageTypes []string + if storageTypesStr := req.GetString("storage_types", ""); storageTypesStr != "" { + storageTypes = strings.Split(storageTypesStr, ",") + for i, storageType := range storageTypes { + storageTypes[i] = strings.TrimSpace(storageType) + } + } + + // 获取可选参数:主机ID + hostId := req.GetString("host_id", "") + + // 获取可选参数:访问凭证 + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器查询块存储列表 + storagesResponse, err := c.adapter.ListStorages(limit, offset, search, cloudregionIds, zoneIds, providers, storageTypes, hostId, ak, sk) + if err != nil { + log.Errorf("Fail to query storage: %s", err) + return nil, fmt.Errorf("fail to query storage: %w", err) + } + + // 格式化查询结果 + formattedResult := c.formatStoragesResult(storagesResponse, limit, offset, search, cloudregionIds, zoneIds, providers, storageTypes, hostId) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回工具的名称标识符 +// +// 返回值: +// - string: 工具名称字符串,用于唯一标识该工具 +func (c *CloudpodsStoragesTool) GetName() string { + return "cloudpods_list_storages" +} + +// formatStoragesResult 格式化块存储列表的响应结果 +// +// 参数: +// - response: 原始响应数据 +// - limit: 查询限制 +// - offset: 查询偏移量 +// - search: 搜索关键词 +// - cloudregionIds: 云区域ID列表 +// - zoneIds: 可用区ID列表 +// - providers: 云平台提供商列表 +// - storageTypes: 存储类型列表 +// - hostId: 主机ID +// +// 返回值: +// - map[string]interface{}: 包含块存储列表的格式化结果 +func (c *CloudpodsStoragesTool) formatStoragesResult( + response *models.StorageListResponse, + limit, offset int, + search string, + cloudregionIds, zoneIds, providers, storageTypes []string, + hostId string, +) map[string]interface{} { + // 初始化格式化结果结构 + formatted := map[string]interface{}{ + "query_info": map[string]interface{}{ + "limit": limit, + "offset": offset, + "search": search, + "cloudregion_ids": cloudregionIds, + "zone_ids": zoneIds, + "providers": providers, + "storage_types": storageTypes, + "host_id": hostId, + "total": response.Total, + "count": len(response.Storages), + }, + "storages": make([]map[string]interface{}, 0, len(response.Storages)), + } + + // 遍历块存储列表,构造每个块存储的详细信息 + for _, storage := range response.Storages { + capacityGB := float64(storage.Capacity) / 1024 + usedCapacityGB := float64(storage.UsedCapacity) / 1024 + freeCapacityGB := float64(storage.FreeCapacity) / 1024 + actualUsedGB := float64(storage.ActualCapacityUsed) / 1024 + + storageInfo := map[string]interface{}{ + "id": storage.Id, + "name": storage.Name, + "description": storage.Description, + "status": storage.Status, + "enabled": storage.Enabled, + "storage_type": storage.StorageType, + "medium_type": storage.MediumType, + "provider": storage.Provider, + "brand": storage.Brand, + "cloud_env": storage.CloudEnv, + "cloudregion": storage.Cloudregion, + "cloudregion_id": storage.CloudregionId, + "zone": storage.Zone, + "zone_id": storage.ZoneId, + "zone_ext_id": storage.ZoneExtId, + "capacity_mb": storage.Capacity, + "capacity_gb": fmt.Sprintf("%.2f GB", capacityGB), + "used_capacity_mb": storage.UsedCapacity, + "used_capacity_gb": fmt.Sprintf("%.2f GB", usedCapacityGB), + "free_capacity_mb": storage.FreeCapacity, + "free_capacity_gb": fmt.Sprintf("%.2f GB", freeCapacityGB), + "actual_capacity_used": storage.ActualCapacityUsed, + "actual_used_gb": fmt.Sprintf("%.2f GB", actualUsedGB), + "virtual_capacity": storage.VirtualCapacity, + "waste_capacity": storage.WasteCapacity, + "reserved": storage.Reserved, + "commit_bound": storage.CommitBound, + "commit_rate": storage.CommitRate, + "cmtbound": storage.Cmtbound, + "is_sys_disk_store": storage.IsSysDiskStore, + "is_public": storage.IsPublic, + "is_emulated": storage.IsEmulated, + "disk_count": storage.DiskCount, + "host_count": storage.HostCount, + "snapshot_count": storage.SnapshotCount, + "master_host": storage.MasterHost, + "master_host_name": storage.MasterHostName, + "storagecache_id": storage.StoragecacheId, + "account": storage.Account, + "account_id": storage.AccountId, + "account_status": storage.AccountStatus, + "account_health_status": storage.AccountHealthStatus, + "account_read_only": storage.AccountReadOnly, + "manager": storage.Manager, + "manager_id": storage.ManagerId, + "manager_domain": storage.ManagerDomain, + "manager_domain_id": storage.ManagerDomainId, + "manager_project": storage.ManagerProject, + "manager_project_id": storage.ManagerProjectId, + "external_id": storage.ExternalId, + "source": storage.Source, + "region": storage.Region, + "region_id": storage.RegionId, + "region_ext_id": storage.RegionExtId, + "region_external_id": storage.RegionExternalId, + "environment": storage.Environment, + "domain_id": storage.DomainId, + "domain_src": storage.DomainSrc, + "project_domain": storage.ProjectDomain, + "public_scope": storage.PublicScope, + "public_src": storage.PublicSrc, + "shared_domains": storage.SharedDomains, + "shared_projects": storage.SharedProjects, + "schedtags": storage.Schedtags, + "hosts": storage.Hosts, + "storage_conf": storage.StorageConf, + "metadata": storage.Metadata, + "progress": storage.Progress, + "can_delete": storage.CanDelete, + "can_update": storage.CanUpdate, + "update_version": storage.UpdateVersion, + "created_at": storage.CreatedAt, + "updated_at": storage.UpdatedAt, + "imported_at": storage.ImportedAt, + } + formatted["storages"] = append(formatted["storages"].([]map[string]interface{}), storageInfo) + } + + // 构造摘要信息 + formatted["summary"] = map[string]interface{}{ + "total_storages": response.Total, // 总存储数量 + "returned_count": len(response.Storages), // 当前返回的存储数量 + "has_more": response.Total > int64(offset+len(response.Storages)), // 是否还有更多数据 + "next_offset": offset + len(response.Storages), // 下一页的偏移量 + } + + return formatted +} diff --git a/pkg/mcp-server/tools/cloudpods_vpcs_tool.go b/pkg/mcp-server/tools/cloudpods_vpcs_tool.go new file mode 100644 index 0000000000..af65b51253 --- /dev/null +++ b/pkg/mcp-server/tools/cloudpods_vpcs_tool.go @@ -0,0 +1,239 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + + "github.com/mark3labs/mcp-go/mcp" + + "yunion.io/x/log" + + "yunion.io/x/onecloud/pkg/mcp-server/adapters" + "yunion.io/x/onecloud/pkg/mcp-server/models" +) + +// CloudpodsVPCsTool 用于查询Cloudpods VPC列表的工具 +// +// 字段: +// - adapter: 用于与Cloudpods API进行交互的适配器 +type CloudpodsVPCsTool struct { + adapter *adapters.CloudpodsAdapter +} + +// NewCloudpodsVPCsTool 创建CloudpodsVPCsTool实例 +// +// 参数: +// - adapter: 用于与Cloudpods API交互的适配器 +// +// 返回值: +// - *CloudpodsVPCsTool: CloudpodsVPCsTool实例指针 +func NewCloudpodsVPCsTool(adapter *adapters.CloudpodsAdapter) *CloudpodsVPCsTool { + return &CloudpodsVPCsTool{ + adapter: adapter, + } +} + +// GetTool 定义并返回查询VPC列表工具的元数据 +// +// 工具用途: +// +// 查询Cloudpods VPC列表,获取虚拟私有网络信息 +// +// 参数说明: +// - limit: 返回结果数量限制,默认为20 +// - offset: 返回结果偏移量,默认为0 +// - search: 搜索关键词,可以按VPC名称搜索 +// - cloudregion_id: 过滤指定云区域的VPC资源 +// - ak: 用户登录cloudpods后获取的access key +// - sk: 用户登录cloudpods后获取的secret key +func (c *CloudpodsVPCsTool) GetTool() mcp.Tool { + return mcp.NewTool( + "cloudpods_list_vpcs", + mcp.WithDescription("查询Cloudpods VPC列表,获取虚拟私有网络信息"), + mcp.WithString("limit", mcp.Description("返回结果数量限制,默认为20")), + mcp.WithString("offset", mcp.Description("返回结果偏移量,默认为0")), + mcp.WithString("search", mcp.Description("搜索关键词,可以按VPC名称搜索")), + mcp.WithString("cloudregion_id", mcp.Description("过滤指定云区域的VPC资源")), + mcp.WithString("ak", mcp.Description("用户登录cloudpods后获取的access key")), + mcp.WithString("sk", mcp.Description("用户登录cloudpods后获取的secret key")), + ) +} + +// Handle 处理查询VPC列表的请求 +// +// 参数: +// - ctx: 控制生命周期的上下文 +// - req: 包含查询参数的请求对象 +// +// 返回值: +// - *mcp.CallToolResult: 包含VPC列表的响应对象 +// - error: 可能的错误信息 +func (c *CloudpodsVPCsTool) Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // 获取可选参数:返回结果数量限制,如果指定则转换为整数 + limit := 20 + if limitStr := req.GetString("limit", ""); limitStr != "" { + if parsedLimit, err := strconv.Atoi(limitStr); err == nil && parsedLimit > 0 { + limit = parsedLimit + } + } + + // 获取可选参数:结果偏移量,如果指定则转换为整数 + offset := 0 + if offsetStr := req.GetString("offset", ""); offsetStr != "" { + if parsedOffset, err := strconv.Atoi(offsetStr); err == nil && parsedOffset >= 0 { + offset = parsedOffset + } + } + + // 获取可选参数:搜索关键词 + search := req.GetString("search", "") + // 获取可选参数:云区域ID + cloudRegionID := req.GetString("cloudregion_id", "") + + // 获取可选参数:访问凭证 + ak := req.GetString("ak", "") + sk := req.GetString("sk", "") + + // 调用适配器查询VPC列表 + vpcsResponse, err := c.adapter.ListVPCs(limit, offset, search, cloudRegionID, ak, sk) + if err != nil { + log.Errorf("Fail to query vpc: %s", err) + return nil, fmt.Errorf("fail to query vpc: %w", err) + } + + // 格式化查询结果 + formattedResult := c.formatVPCsResult(vpcsResponse, limit, offset, search, cloudRegionID) + + // 将结果序列化为JSON格式 + resultJSON, err := json.MarshalIndent(formattedResult, "", " ") + if err != nil { + log.Errorf("Fail to serialize result: %s", err) + return nil, fmt.Errorf("fail to serialize result: %w", err) + } + + return mcp.NewToolResultText(string(resultJSON)), nil +} + +// GetName 返回工具的名称标识符 +// +// 返回值: +// - string: 工具名称字符串,用于唯一标识该工具 +func (c *CloudpodsVPCsTool) GetName() string { + return "cloudpods_list_vpcs" +} + +// formatVPCsResult 格式化VPC列表的响应结果 +// +// 参数: +// - response: 原始响应数据 +// - limit: 查询限制 +// - offset: 查询偏移量 +// - search: 搜索关键词 +// - cloudRegionID: 云区域ID +// +// 返回值: +// - map[string]interface{}: 包含VPC列表的格式化结果 +func (c *CloudpodsVPCsTool) formatVPCsResult(response *models.VpcListResponse, limit, offset int, search, cloudRegionID string) map[string]interface{} { + // 初始化格式化结果结构 + formatted := map[string]interface{}{ + "query_info": map[string]interface{}{ + "limit": limit, + "offset": offset, + "search": search, + "cloudregion_id": cloudRegionID, + "total": response.Total, + "count": len(response.Vpcs), + }, + "vpcs": make([]map[string]interface{}, 0, len(response.Vpcs)), + } + + // 遍历VPC列表,构造每个VPC的详细信息 + for _, vpc := range response.Vpcs { + vpcInfo := map[string]interface{}{ + "id": vpc.Id, + "name": vpc.Name, + "description": vpc.Description, + "cidr_block": vpc.CidrBlock, + "cidr_block6": vpc.CidrBlock6, + "status": vpc.Status, + "enabled": vpc.Enabled, + "is_default": vpc.IsDefault, + "is_public": vpc.IsPublic, + "provider": vpc.Provider, + "brand": vpc.Brand, + "cloud_env": vpc.CloudEnv, + "environment": vpc.Environment, + "cloudregion": vpc.Cloudregion, + "cloudregion_id": vpc.CloudregionId, + "region": vpc.Region, + "region_id": vpc.RegionId, + "external_id": vpc.ExternalId, + "external_access_mode": vpc.ExternalAccessMode, + "globalvpc": vpc.Globalvpc, + "globalvpc_id": vpc.GlobalvpcId, + "account": vpc.Account, + "account_id": vpc.AccountId, + "account_status": vpc.AccountStatus, + "account_health_status": vpc.AccountHealthStatus, + "manager": vpc.Manager, + "manager_id": vpc.ManagerId, + "manager_domain": vpc.ManagerDomain, + "manager_domain_id": vpc.ManagerDomainId, + "manager_project": vpc.ManagerProject, + "manager_project_id": vpc.ManagerProjectId, + "network_count": vpc.NetworkCount, + "wire_count": vpc.WireCount, + "dns_zone_count": vpc.DnsZoneCount, + "natgateway_count": vpc.NatgatewayCount, + "routetable_count": vpc.RoutetableCount, + "accept_vpc_peer_count": vpc.AcceptVpcPeerCount, + "request_vpc_peer_count": vpc.RequestVpcPeerCount, + "direct": vpc.Direct, + "domain_id": vpc.DomainId, + "domain_src": vpc.DomainSrc, + "project_domain": vpc.ProjectDomain, + "public_scope": vpc.PublicScope, + "public_src": vpc.PublicSrc, + "region_ext_id": vpc.RegionExtId, + "region_external_id": vpc.RegionExternalId, + "source": vpc.Source, + "progress": vpc.Progress, + "shared_domains": vpc.SharedDomains, + "shared_projects": vpc.SharedProjects, + "can_delete": vpc.CanDelete, + "can_update": vpc.CanUpdate, + "is_emulated": vpc.IsEmulated, + "metadata": vpc.Metadata, + "created_at": vpc.CreatedAt, + "updated_at": vpc.UpdatedAt, + "imported_at": vpc.ImportedAt, + } + formatted["vpcs"] = append(formatted["vpcs"].([]map[string]interface{}), vpcInfo) + } + + // 构造摘要信息 + formatted["summary"] = map[string]interface{}{ + "total_vpcs": response.Total, + "returned_count": len(response.Vpcs), + "has_more": response.Total > int64(offset+len(response.Vpcs)), + "next_offset": offset + len(response.Vpcs), + } + + return formatted +} diff --git a/pkg/mcp-server/tools/doc.go b/pkg/mcp-server/tools/doc.go new file mode 100644 index 0000000000..1b8062821d --- /dev/null +++ b/pkg/mcp-server/tools/doc.go @@ -0,0 +1 @@ +package tools // import "yunion.io/x/onecloud/pkg/mcp-server/tools" diff --git a/pkg/mcp-server/tools/tools.go b/pkg/mcp-server/tools/tools.go new file mode 100644 index 0000000000..c187e69117 --- /dev/null +++ b/pkg/mcp-server/tools/tools.go @@ -0,0 +1,31 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// Tool 是所有工具的接口,定义了工具的基本方法 +// GetTool 返回 MCP 工具定义 +// Handle 处理工具调用请求 +// GetName 返回工具名称 +type Tool interface { + GetTool() mcp.Tool + Handle(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) + GetName() string +} diff --git a/vendor/github.com/bahlo/generic-list-go/LICENSE b/vendor/github.com/bahlo/generic-list-go/LICENSE new file mode 100644 index 0000000000..6a66aea5ea --- /dev/null +++ b/vendor/github.com/bahlo/generic-list-go/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/bahlo/generic-list-go/README.md b/vendor/github.com/bahlo/generic-list-go/README.md new file mode 100644 index 0000000000..68bbce9fba --- /dev/null +++ b/vendor/github.com/bahlo/generic-list-go/README.md @@ -0,0 +1,5 @@ +# generic-list-go [![CI](https://github.com/bahlo/generic-list-go/actions/workflows/ci.yml/badge.svg)](https://github.com/bahlo/generic-list-go/actions/workflows/ci.yml) + +Go [container/list](https://pkg.go.dev/container/list) but with generics. + +The code is based on `container/list` in `go1.18beta2`. diff --git a/vendor/github.com/bahlo/generic-list-go/list.go b/vendor/github.com/bahlo/generic-list-go/list.go new file mode 100644 index 0000000000..a06a7c6129 --- /dev/null +++ b/vendor/github.com/bahlo/generic-list-go/list.go @@ -0,0 +1,235 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package list implements a doubly linked list. +// +// To iterate over a list (where l is a *List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e.Value +// } +// +package list + +// Element is an element of a linked list. +type Element[T any] struct { + // Next and previous pointers in the doubly-linked list of elements. + // To simplify the implementation, internally a list l is implemented + // as a ring, such that &l.root is both the next element of the last + // list element (l.Back()) and the previous element of the first list + // element (l.Front()). + next, prev *Element[T] + + // The list to which this element belongs. + list *List[T] + + // The value stored with this element. + Value T +} + +// Next returns the next list element or nil. +func (e *Element[T]) Next() *Element[T] { + if p := e.next; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// Prev returns the previous list element or nil. +func (e *Element[T]) Prev() *Element[T] { + if p := e.prev; e.list != nil && p != &e.list.root { + return p + } + return nil +} + +// List represents a doubly linked list. +// The zero value for List is an empty list ready to use. +type List[T any] struct { + root Element[T] // sentinel list element, only &root, root.prev, and root.next are used + len int // current list length excluding (this) sentinel element +} + +// Init initializes or clears list l. +func (l *List[T]) Init() *List[T] { + l.root.next = &l.root + l.root.prev = &l.root + l.len = 0 + return l +} + +// New returns an initialized list. +func New[T any]() *List[T] { return new(List[T]).Init() } + +// Len returns the number of elements of list l. +// The complexity is O(1). +func (l *List[T]) Len() int { return l.len } + +// Front returns the first element of list l or nil if the list is empty. +func (l *List[T]) Front() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.next +} + +// Back returns the last element of list l or nil if the list is empty. +func (l *List[T]) Back() *Element[T] { + if l.len == 0 { + return nil + } + return l.root.prev +} + +// lazyInit lazily initializes a zero List value. +func (l *List[T]) lazyInit() { + if l.root.next == nil { + l.Init() + } +} + +// insert inserts e after at, increments l.len, and returns e. +func (l *List[T]) insert(e, at *Element[T]) *Element[T] { + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e + e.list = l + l.len++ + return e +} + +// insertValue is a convenience wrapper for insert(&Element{Value: v}, at). +func (l *List[T]) insertValue(v T, at *Element[T]) *Element[T] { + return l.insert(&Element[T]{Value: v}, at) +} + +// remove removes e from its list, decrements l.len +func (l *List[T]) remove(e *Element[T]) { + e.prev.next = e.next + e.next.prev = e.prev + e.next = nil // avoid memory leaks + e.prev = nil // avoid memory leaks + e.list = nil + l.len-- +} + +// move moves e to next to at. +func (l *List[T]) move(e, at *Element[T]) { + if e == at { + return + } + e.prev.next = e.next + e.next.prev = e.prev + + e.prev = at + e.next = at.next + e.prev.next = e + e.next.prev = e +} + +// Remove removes e from l if e is an element of list l. +// It returns the element value e.Value. +// The element must not be nil. +func (l *List[T]) Remove(e *Element[T]) T { + if e.list == l { + // if e.list == l, l must have been initialized when e was inserted + // in l or l == nil (e is a zero Element) and l.remove will crash + l.remove(e) + } + return e.Value +} + +// PushFront inserts a new element e with value v at the front of list l and returns e. +func (l *List[T]) PushFront(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, &l.root) +} + +// PushBack inserts a new element e with value v at the back of list l and returns e. +func (l *List[T]) PushBack(v T) *Element[T] { + l.lazyInit() + return l.insertValue(v, l.root.prev) +} + +// InsertBefore inserts a new element e with value v immediately before mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *List[T]) InsertBefore(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark.prev) +} + +// InsertAfter inserts a new element e with value v immediately after mark and returns e. +// If mark is not an element of l, the list is not modified. +// The mark must not be nil. +func (l *List[T]) InsertAfter(v T, mark *Element[T]) *Element[T] { + if mark.list != l { + return nil + } + // see comment in List.Remove about initialization of l + return l.insertValue(v, mark) +} + +// MoveToFront moves element e to the front of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[T]) MoveToFront(e *Element[T]) { + if e.list != l || l.root.next == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, &l.root) +} + +// MoveToBack moves element e to the back of list l. +// If e is not an element of l, the list is not modified. +// The element must not be nil. +func (l *List[T]) MoveToBack(e *Element[T]) { + if e.list != l || l.root.prev == e { + return + } + // see comment in List.Remove about initialization of l + l.move(e, l.root.prev) +} + +// MoveBefore moves element e to its new position before mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[T]) MoveBefore(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark.prev) +} + +// MoveAfter moves element e to its new position after mark. +// If e or mark is not an element of l, or e == mark, the list is not modified. +// The element and mark must not be nil. +func (l *List[T]) MoveAfter(e, mark *Element[T]) { + if e.list != l || e == mark || mark.list != l { + return + } + l.move(e, mark) +} + +// PushBackList inserts a copy of another list at the back of list l. +// The lists l and other may be the same. They must not be nil. +func (l *List[T]) PushBackList(other *List[T]) { + l.lazyInit() + for i, e := other.Len(), other.Front(); i > 0; i, e = i-1, e.Next() { + l.insertValue(e.Value, l.root.prev) + } +} + +// PushFrontList inserts a copy of another list at the front of list l. +// The lists l and other may be the same. They must not be nil. +func (l *List[T]) PushFrontList(other *List[T]) { + l.lazyInit() + for i, e := other.Len(), other.Back(); i > 0; i, e = i-1, e.Prev() { + l.insertValue(e.Value, &l.root) + } +} diff --git a/vendor/github.com/buger/jsonparser/.gitignore b/vendor/github.com/buger/jsonparser/.gitignore new file mode 100644 index 0000000000..5598d8a569 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/.gitignore @@ -0,0 +1,12 @@ + +*.test + +*.out + +*.mprof + +.idea + +vendor/github.com/buger/goterm/ +prof.cpu +prof.mem diff --git a/vendor/github.com/buger/jsonparser/.travis.yml b/vendor/github.com/buger/jsonparser/.travis.yml new file mode 100644 index 0000000000..dbfb7cf988 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/.travis.yml @@ -0,0 +1,11 @@ +language: go +arch: + - amd64 + - ppc64le +go: + - 1.7.x + - 1.8.x + - 1.9.x + - 1.10.x + - 1.11.x +script: go test -v ./. diff --git a/vendor/github.com/buger/jsonparser/Dockerfile b/vendor/github.com/buger/jsonparser/Dockerfile new file mode 100644 index 0000000000..37fc9fd0b4 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/Dockerfile @@ -0,0 +1,12 @@ +FROM golang:1.6 + +RUN go get github.com/Jeffail/gabs +RUN go get github.com/bitly/go-simplejson +RUN go get github.com/pquerna/ffjson +RUN go get github.com/antonholmquist/jason +RUN go get github.com/mreiferson/go-ujson +RUN go get -tags=unsafe -u github.com/ugorji/go/codec +RUN go get github.com/mailru/easyjson + +WORKDIR /go/src/github.com/buger/jsonparser +ADD . /go/src/github.com/buger/jsonparser \ No newline at end of file diff --git a/vendor/github.com/buger/jsonparser/LICENSE b/vendor/github.com/buger/jsonparser/LICENSE new file mode 100644 index 0000000000..ac25aeb7da --- /dev/null +++ b/vendor/github.com/buger/jsonparser/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2016 Leonid Bugaev + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/buger/jsonparser/Makefile b/vendor/github.com/buger/jsonparser/Makefile new file mode 100644 index 0000000000..e843368cf1 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/Makefile @@ -0,0 +1,36 @@ +SOURCE = parser.go +CONTAINER = jsonparser +SOURCE_PATH = /go/src/github.com/buger/jsonparser +BENCHMARK = JsonParser +BENCHTIME = 5s +TEST = . +DRUN = docker run -v `pwd`:$(SOURCE_PATH) -i -t $(CONTAINER) + +build: + docker build -t $(CONTAINER) . + +race: + $(DRUN) --env GORACE="halt_on_error=1" go test ./. $(ARGS) -v -race -timeout 15s + +bench: + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -benchtime $(BENCHTIME) -v + +bench_local: + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench . $(ARGS) -benchtime $(BENCHTIME) -v + +profile: + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -memprofile mem.mprof -v + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -cpuprofile cpu.out -v + $(DRUN) go test $(LDFLAGS) -test.benchmem -bench $(BENCHMARK) ./benchmark/ $(ARGS) -c + +test: + $(DRUN) go test $(LDFLAGS) ./ -run $(TEST) -timeout 10s $(ARGS) -v + +fmt: + $(DRUN) go fmt ./... + +vet: + $(DRUN) go vet ./. + +bash: + $(DRUN) /bin/bash \ No newline at end of file diff --git a/vendor/github.com/buger/jsonparser/README.md b/vendor/github.com/buger/jsonparser/README.md new file mode 100644 index 0000000000..d7e0ec397a --- /dev/null +++ b/vendor/github.com/buger/jsonparser/README.md @@ -0,0 +1,365 @@ +[![Go Report Card](https://goreportcard.com/badge/github.com/buger/jsonparser)](https://goreportcard.com/report/github.com/buger/jsonparser) ![License](https://img.shields.io/dub/l/vibe-d.svg) +# Alternative JSON parser for Go (10x times faster standard library) + +It does not require you to know the structure of the payload (eg. create structs), and allows accessing fields by providing the path to them. It is up to **10 times faster** than standard `encoding/json` package (depending on payload size and usage), **allocates no memory**. See benchmarks below. + +## Rationale +Originally I made this for a project that relies on a lot of 3rd party APIs that can be unpredictable and complex. +I love simplicity and prefer to avoid external dependecies. `encoding/json` requires you to know exactly your data structures, or if you prefer to use `map[string]interface{}` instead, it will be very slow and hard to manage. +I investigated what's on the market and found that most libraries are just wrappers around `encoding/json`, there is few options with own parsers (`ffjson`, `easyjson`), but they still requires you to create data structures. + + +Goal of this project is to push JSON parser to the performance limits and not sacrifice with compliance and developer user experience. + +## Example +For the given JSON our goal is to extract the user's full name, number of github followers and avatar. + +```go +import "github.com/buger/jsonparser" + +... + +data := []byte(`{ + "person": { + "name": { + "first": "Leonid", + "last": "Bugaev", + "fullName": "Leonid Bugaev" + }, + "github": { + "handle": "buger", + "followers": 109 + }, + "avatars": [ + { "url": "https://avatars1.githubusercontent.com/u/14009?v=3&s=460", "type": "thumbnail" } + ] + }, + "company": { + "name": "Acme" + } +}`) + +// You can specify key path by providing arguments to Get function +jsonparser.Get(data, "person", "name", "fullName") + +// There is `GetInt` and `GetBoolean` helpers if you exactly know key data type +jsonparser.GetInt(data, "person", "github", "followers") + +// When you try to get object, it will return you []byte slice pointer to data containing it +// In `company` it will be `{"name": "Acme"}` +jsonparser.Get(data, "company") + +// If the key doesn't exist it will throw an error +var size int64 +if value, err := jsonparser.GetInt(data, "company", "size"); err == nil { + size = value +} + +// You can use `ArrayEach` helper to iterate items [item1, item2 .... itemN] +jsonparser.ArrayEach(data, func(value []byte, dataType jsonparser.ValueType, offset int, err error) { + fmt.Println(jsonparser.Get(value, "url")) +}, "person", "avatars") + +// Or use can access fields by index! +jsonparser.GetString(data, "person", "avatars", "[0]", "url") + +// You can use `ObjectEach` helper to iterate objects { "key1":object1, "key2":object2, .... "keyN":objectN } +jsonparser.ObjectEach(data, func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error { + fmt.Printf("Key: '%s'\n Value: '%s'\n Type: %s\n", string(key), string(value), dataType) + return nil +}, "person", "name") + +// The most efficient way to extract multiple keys is `EachKey` + +paths := [][]string{ + []string{"person", "name", "fullName"}, + []string{"person", "avatars", "[0]", "url"}, + []string{"company", "url"}, +} +jsonparser.EachKey(data, func(idx int, value []byte, vt jsonparser.ValueType, err error){ + switch idx { + case 0: // []string{"person", "name", "fullName"} + ... + case 1: // []string{"person", "avatars", "[0]", "url"} + ... + case 2: // []string{"company", "url"}, + ... + } +}, paths...) + +// For more information see docs below +``` + +## Need to speedup your app? + +I'm available for consulting and can help you push your app performance to the limits. Ping me at: leonsbox@gmail.com. + +## Reference + +Library API is really simple. You just need the `Get` method to perform any operation. The rest is just helpers around it. + +You also can view API at [godoc.org](https://godoc.org/github.com/buger/jsonparser) + + +### **`Get`** +```go +func Get(data []byte, keys ...string) (value []byte, dataType jsonparser.ValueType, offset int, err error) +``` +Receives data structure, and key path to extract value from. + +Returns: +* `value` - Pointer to original data structure containing key value, or just empty slice if nothing found or error +* `dataType` - Can be: `NotExist`, `String`, `Number`, `Object`, `Array`, `Boolean` or `Null` +* `offset` - Offset from provided data structure where key value ends. Used mostly internally, for example for `ArrayEach` helper. +* `err` - If the key is not found or any other parsing issue, it should return error. If key not found it also sets `dataType` to `NotExist` + +Accepts multiple keys to specify path to JSON value (in case of quering nested structures). +If no keys are provided it will try to extract the closest JSON value (simple ones or object/array), useful for reading streams or arrays, see `ArrayEach` implementation. + +Note that keys can be an array indexes: `jsonparser.GetInt("person", "avatars", "[0]", "url")`, pretty cool, yeah? + +### **`GetString`** +```go +func GetString(data []byte, keys ...string) (val string, err error) +``` +Returns strings properly handing escaped and unicode characters. Note that this will cause additional memory allocations. + +### **`GetUnsafeString`** +If you need string in your app, and ready to sacrifice with support of escaped symbols in favor of speed. It returns string mapped to existing byte slice memory, without any allocations: +```go +s, _, := jsonparser.GetUnsafeString(data, "person", "name", "title") +switch s { + case 'CEO': + ... + case 'Engineer' + ... + ... +} +``` +Note that `unsafe` here means that your string will exist until GC will free underlying byte slice, for most of cases it means that you can use this string only in current context, and should not pass it anywhere externally: through channels or any other way. + + +### **`GetBoolean`**, **`GetInt`** and **`GetFloat`** +```go +func GetBoolean(data []byte, keys ...string) (val bool, err error) + +func GetFloat(data []byte, keys ...string) (val float64, err error) + +func GetInt(data []byte, keys ...string) (val int64, err error) +``` +If you know the key type, you can use the helpers above. +If key data type do not match, it will return error. + +### **`ArrayEach`** +```go +func ArrayEach(data []byte, cb func(value []byte, dataType jsonparser.ValueType, offset int, err error), keys ...string) +``` +Needed for iterating arrays, accepts a callback function with the same return arguments as `Get`. + +### **`ObjectEach`** +```go +func ObjectEach(data []byte, callback func(key []byte, value []byte, dataType ValueType, offset int) error, keys ...string) (err error) +``` +Needed for iterating object, accepts a callback function. Example: +```go +var handler func([]byte, []byte, jsonparser.ValueType, int) error +handler = func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error { + //do stuff here +} +jsonparser.ObjectEach(myJson, handler) +``` + + +### **`EachKey`** +```go +func EachKey(data []byte, cb func(idx int, value []byte, dataType jsonparser.ValueType, err error), paths ...[]string) +``` +When you need to read multiple keys, and you do not afraid of low-level API `EachKey` is your friend. It read payload only single time, and calls callback function once path is found. For example when you call multiple times `Get`, it has to process payload multiple times, each time you call it. Depending on payload `EachKey` can be multiple times faster than `Get`. Path can use nested keys as well! + +```go +paths := [][]string{ + []string{"uuid"}, + []string{"tz"}, + []string{"ua"}, + []string{"st"}, +} +var data SmallPayload + +jsonparser.EachKey(smallFixture, func(idx int, value []byte, vt jsonparser.ValueType, err error){ + switch idx { + case 0: + data.Uuid, _ = value + case 1: + v, _ := jsonparser.ParseInt(value) + data.Tz = int(v) + case 2: + data.Ua, _ = value + case 3: + v, _ := jsonparser.ParseInt(value) + data.St = int(v) + } +}, paths...) +``` + +### **`Set`** +```go +func Set(data []byte, setValue []byte, keys ...string) (value []byte, err error) +``` +Receives existing data structure, key path to set, and value to set at that key. *This functionality is experimental.* + +Returns: +* `value` - Pointer to original data structure with updated or added key value. +* `err` - If any parsing issue, it should return error. + +Accepts multiple keys to specify path to JSON value (in case of updating or creating nested structures). + +Note that keys can be an array indexes: `jsonparser.Set(data, []byte("http://github.com"), "person", "avatars", "[0]", "url")` + +### **`Delete`** +```go +func Delete(data []byte, keys ...string) value []byte +``` +Receives existing data structure, and key path to delete. *This functionality is experimental.* + +Returns: +* `value` - Pointer to original data structure with key path deleted if it can be found. If there is no key path, then the whole data structure is deleted. + +Accepts multiple keys to specify path to JSON value (in case of updating or creating nested structures). + +Note that keys can be an array indexes: `jsonparser.Delete(data, "person", "avatars", "[0]", "url")` + + +## What makes it so fast? +* It does not rely on `encoding/json`, `reflection` or `interface{}`, the only real package dependency is `bytes`. +* Operates with JSON payload on byte level, providing you pointers to the original data structure: no memory allocation. +* No automatic type conversions, by default everything is a []byte, but it provides you value type, so you can convert by yourself (there is few helpers included). +* Does not parse full record, only keys you specified + + +## Benchmarks + +There are 3 benchmark types, trying to simulate real-life usage for small, medium and large JSON payloads. +For each metric, the lower value is better. Time/op is in nanoseconds. Values better than standard encoding/json marked as bold text. +Benchmarks run on standard Linode 1024 box. + +Compared libraries: +* https://golang.org/pkg/encoding/json +* https://github.com/Jeffail/gabs +* https://github.com/a8m/djson +* https://github.com/bitly/go-simplejson +* https://github.com/antonholmquist/jason +* https://github.com/mreiferson/go-ujson +* https://github.com/ugorji/go/codec +* https://github.com/pquerna/ffjson +* https://github.com/mailru/easyjson +* https://github.com/buger/jsonparser + +#### TLDR +If you want to skip next sections we have 2 winner: `jsonparser` and `easyjson`. +`jsonparser` is up to 10 times faster than standard `encoding/json` package (depending on payload size and usage), and almost infinitely (literally) better in memory consumption because it operates with data on byte level, and provide direct slice pointers. +`easyjson` wins in CPU in medium tests and frankly i'm impressed with this package: it is remarkable results considering that it is almost drop-in replacement for `encoding/json` (require some code generation). + +It's hard to fully compare `jsonparser` and `easyjson` (or `ffson`), they a true parsers and fully process record, unlike `jsonparser` which parse only keys you specified. + +If you searching for replacement of `encoding/json` while keeping structs, `easyjson` is an amazing choice. If you want to process dynamic JSON, have memory constrains, or more control over your data you should try `jsonparser`. + +`jsonparser` performance heavily depends on usage, and it works best when you do not need to process full record, only some keys. The more calls you need to make, the slower it will be, in contrast `easyjson` (or `ffjson`, `encoding/json`) parser record only 1 time, and then you can make as many calls as you want. + +With great power comes great responsibility! :) + + +#### Small payload + +Each test processes 190 bytes of http log as a JSON record. +It should read multiple fields. +https://github.com/buger/jsonparser/blob/master/benchmark/benchmark_small_payload_test.go + +Library | time/op | bytes/op | allocs/op + ------ | ------- | -------- | ------- +encoding/json struct | 7879 | 880 | 18 +encoding/json interface{} | 8946 | 1521 | 38 +Jeffail/gabs | 10053 | 1649 | 46 +bitly/go-simplejson | 10128 | 2241 | 36 +antonholmquist/jason | 27152 | 7237 | 101 +github.com/ugorji/go/codec | 8806 | 2176 | 31 +mreiferson/go-ujson | **7008** | **1409** | 37 +a8m/djson | 3862 | 1249 | 30 +pquerna/ffjson | **3769** | **624** | **15** +mailru/easyjson | **2002** | **192** | **9** +buger/jsonparser | **1367** | **0** | **0** +buger/jsonparser (EachKey API) | **809** | **0** | **0** + +Winners are ffjson, easyjson and jsonparser, where jsonparser is up to 9.8x faster than encoding/json and 4.6x faster than ffjson, and slightly faster than easyjson. +If you look at memory allocation, jsonparser has no rivals, as it makes no data copy and operates with raw []byte structures and pointers to it. + +#### Medium payload + +Each test processes a 2.4kb JSON record (based on Clearbit API). +It should read multiple nested fields and 1 array. + +https://github.com/buger/jsonparser/blob/master/benchmark/benchmark_medium_payload_test.go + +| Library | time/op | bytes/op | allocs/op | +| ------- | ------- | -------- | --------- | +| encoding/json struct | 57749 | 1336 | 29 | +| encoding/json interface{} | 79297 | 10627 | 215 | +| Jeffail/gabs | 83807 | 11202 | 235 | +| bitly/go-simplejson | 88187 | 17187 | 220 | +| antonholmquist/jason | 94099 | 19013 | 247 | +| github.com/ugorji/go/codec | 114719 | 6712 | 152 | +| mreiferson/go-ujson | **56972** | 11547 | 270 | +| a8m/djson | 28525 | 10196 | 198 | +| pquerna/ffjson | **20298** | **856** | **20** | +| mailru/easyjson | **10512** | **336** | **12** | +| buger/jsonparser | **15955** | **0** | **0** | +| buger/jsonparser (EachKey API) | **8916** | **0** | **0** | + +The difference between ffjson and jsonparser in CPU usage is smaller, while the memory consumption difference is growing. On the other hand `easyjson` shows remarkable performance for medium payload. + +`gabs`, `go-simplejson` and `jason` are based on encoding/json and map[string]interface{} and actually only helpers for unstructured JSON, their performance correlate with `encoding/json interface{}`, and they will skip next round. +`go-ujson` while have its own parser, shows same performance as `encoding/json`, also skips next round. Same situation with `ugorji/go/codec`, but it showed unexpectedly bad performance for complex payloads. + + +#### Large payload + +Each test processes a 24kb JSON record (based on Discourse API) +It should read 2 arrays, and for each item in array get a few fields. +Basically it means processing a full JSON file. + +https://github.com/buger/jsonparser/blob/master/benchmark/benchmark_large_payload_test.go + +| Library | time/op | bytes/op | allocs/op | +| --- | --- | --- | --- | +| encoding/json struct | 748336 | 8272 | 307 | +| encoding/json interface{} | 1224271 | 215425 | 3395 | +| a8m/djson | 510082 | 213682 | 2845 | +| pquerna/ffjson | **312271** | **7792** | **298** | +| mailru/easyjson | **154186** | **6992** | **288** | +| buger/jsonparser | **85308** | **0** | **0** | + +`jsonparser` now is a winner, but do not forget that it is way more lightweight parser than `ffson` or `easyjson`, and they have to parser all the data, while `jsonparser` parse only what you need. All `ffjson`, `easysjon` and `jsonparser` have their own parsing code, and does not depend on `encoding/json` or `interface{}`, thats one of the reasons why they are so fast. `easyjson` also use a bit of `unsafe` package to reduce memory consuption (in theory it can lead to some unexpected GC issue, but i did not tested enough) + +Also last benchmark did not included `EachKey` test, because in this particular case we need to read lot of Array values, and using `ArrayEach` is more efficient. + +## Questions and support + +All bug-reports and suggestions should go though Github Issues. + +## Contributing + +1. Fork it +2. Create your feature branch (git checkout -b my-new-feature) +3. Commit your changes (git commit -am 'Added some feature') +4. Push to the branch (git push origin my-new-feature) +5. Create new Pull Request + +## Development + +All my development happens using Docker, and repo include some Make tasks to simplify development. + +* `make build` - builds docker image, usually can be called only once +* `make test` - run tests +* `make fmt` - run go fmt +* `make bench` - run benchmarks (if you need to run only single benchmark modify `BENCHMARK` variable in make file) +* `make profile` - runs benchmark and generate 3 files- `cpu.out`, `mem.mprof` and `benchmark.test` binary, which can be used for `go tool pprof` +* `make bash` - enter container (i use it for running `go tool pprof` above) diff --git a/vendor/github.com/buger/jsonparser/bytes.go b/vendor/github.com/buger/jsonparser/bytes.go new file mode 100644 index 0000000000..0bb0ff3956 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/bytes.go @@ -0,0 +1,47 @@ +package jsonparser + +import ( + bio "bytes" +) + +// minInt64 '-9223372036854775808' is the smallest representable number in int64 +const minInt64 = `9223372036854775808` + +// About 2x faster then strconv.ParseInt because it only supports base 10, which is enough for JSON +func parseInt(bytes []byte) (v int64, ok bool, overflow bool) { + if len(bytes) == 0 { + return 0, false, false + } + + var neg bool = false + if bytes[0] == '-' { + neg = true + bytes = bytes[1:] + } + + var b int64 = 0 + for _, c := range bytes { + if c >= '0' && c <= '9' { + b = (10 * v) + int64(c-'0') + } else { + return 0, false, false + } + if overflow = (b < v); overflow { + break + } + v = b + } + + if overflow { + if neg && bio.Equal(bytes, []byte(minInt64)) { + return b, true, false + } + return 0, false, true + } + + if neg { + return -v, true, false + } else { + return v, true, false + } +} diff --git a/vendor/github.com/buger/jsonparser/bytes_safe.go b/vendor/github.com/buger/jsonparser/bytes_safe.go new file mode 100644 index 0000000000..ff16a4a195 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/bytes_safe.go @@ -0,0 +1,25 @@ +// +build appengine appenginevm + +package jsonparser + +import ( + "strconv" +) + +// See fastbytes_unsafe.go for explanation on why *[]byte is used (signatures must be consistent with those in that file) + +func equalStr(b *[]byte, s string) bool { + return string(*b) == s +} + +func parseFloat(b *[]byte) (float64, error) { + return strconv.ParseFloat(string(*b), 64) +} + +func bytesToString(b *[]byte) string { + return string(*b) +} + +func StringToBytes(s string) []byte { + return []byte(s) +} diff --git a/vendor/github.com/buger/jsonparser/bytes_unsafe.go b/vendor/github.com/buger/jsonparser/bytes_unsafe.go new file mode 100644 index 0000000000..589fea87eb --- /dev/null +++ b/vendor/github.com/buger/jsonparser/bytes_unsafe.go @@ -0,0 +1,44 @@ +// +build !appengine,!appenginevm + +package jsonparser + +import ( + "reflect" + "strconv" + "unsafe" + "runtime" +) + +// +// The reason for using *[]byte rather than []byte in parameters is an optimization. As of Go 1.6, +// the compiler cannot perfectly inline the function when using a non-pointer slice. That is, +// the non-pointer []byte parameter version is slower than if its function body is manually +// inlined, whereas the pointer []byte version is equally fast to the manually inlined +// version. Instruction count in assembly taken from "go tool compile" confirms this difference. +// +// TODO: Remove hack after Go 1.7 release +// +func equalStr(b *[]byte, s string) bool { + return *(*string)(unsafe.Pointer(b)) == s +} + +func parseFloat(b *[]byte) (float64, error) { + return strconv.ParseFloat(*(*string)(unsafe.Pointer(b)), 64) +} + +// A hack until issue golang/go#2632 is fixed. +// See: https://github.com/golang/go/issues/2632 +func bytesToString(b *[]byte) string { + return *(*string)(unsafe.Pointer(b)) +} + +func StringToBytes(s string) []byte { + b := make([]byte, 0, 0) + bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) + bh.Data = sh.Data + bh.Cap = sh.Len + bh.Len = sh.Len + runtime.KeepAlive(s) + return b +} diff --git a/vendor/github.com/buger/jsonparser/escape.go b/vendor/github.com/buger/jsonparser/escape.go new file mode 100644 index 0000000000..49669b9420 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/escape.go @@ -0,0 +1,173 @@ +package jsonparser + +import ( + "bytes" + "unicode/utf8" +) + +// JSON Unicode stuff: see https://tools.ietf.org/html/rfc7159#section-7 + +const supplementalPlanesOffset = 0x10000 +const highSurrogateOffset = 0xD800 +const lowSurrogateOffset = 0xDC00 + +const basicMultilingualPlaneReservedOffset = 0xDFFF +const basicMultilingualPlaneOffset = 0xFFFF + +func combineUTF16Surrogates(high, low rune) rune { + return supplementalPlanesOffset + (high-highSurrogateOffset)<<10 + (low - lowSurrogateOffset) +} + +const badHex = -1 + +func h2I(c byte) int { + switch { + case c >= '0' && c <= '9': + return int(c - '0') + case c >= 'A' && c <= 'F': + return int(c - 'A' + 10) + case c >= 'a' && c <= 'f': + return int(c - 'a' + 10) + } + return badHex +} + +// decodeSingleUnicodeEscape decodes a single \uXXXX escape sequence. The prefix \u is assumed to be present and +// is not checked. +// In JSON, these escapes can either come alone or as part of "UTF16 surrogate pairs" that must be handled together. +// This function only handles one; decodeUnicodeEscape handles this more complex case. +func decodeSingleUnicodeEscape(in []byte) (rune, bool) { + // We need at least 6 characters total + if len(in) < 6 { + return utf8.RuneError, false + } + + // Convert hex to decimal + h1, h2, h3, h4 := h2I(in[2]), h2I(in[3]), h2I(in[4]), h2I(in[5]) + if h1 == badHex || h2 == badHex || h3 == badHex || h4 == badHex { + return utf8.RuneError, false + } + + // Compose the hex digits + return rune(h1<<12 + h2<<8 + h3<<4 + h4), true +} + +// isUTF16EncodedRune checks if a rune is in the range for non-BMP characters, +// which is used to describe UTF16 chars. +// Source: https://en.wikipedia.org/wiki/Plane_(Unicode)#Basic_Multilingual_Plane +func isUTF16EncodedRune(r rune) bool { + return highSurrogateOffset <= r && r <= basicMultilingualPlaneReservedOffset +} + +func decodeUnicodeEscape(in []byte) (rune, int) { + if r, ok := decodeSingleUnicodeEscape(in); !ok { + // Invalid Unicode escape + return utf8.RuneError, -1 + } else if r <= basicMultilingualPlaneOffset && !isUTF16EncodedRune(r) { + // Valid Unicode escape in Basic Multilingual Plane + return r, 6 + } else if r2, ok := decodeSingleUnicodeEscape(in[6:]); !ok { // Note: previous decodeSingleUnicodeEscape success guarantees at least 6 bytes remain + // UTF16 "high surrogate" without manditory valid following Unicode escape for the "low surrogate" + return utf8.RuneError, -1 + } else if r2 < lowSurrogateOffset { + // Invalid UTF16 "low surrogate" + return utf8.RuneError, -1 + } else { + // Valid UTF16 surrogate pair + return combineUTF16Surrogates(r, r2), 12 + } +} + +// backslashCharEscapeTable: when '\X' is found for some byte X, it is to be replaced with backslashCharEscapeTable[X] +var backslashCharEscapeTable = [...]byte{ + '"': '"', + '\\': '\\', + '/': '/', + 'b': '\b', + 'f': '\f', + 'n': '\n', + 'r': '\r', + 't': '\t', +} + +// unescapeToUTF8 unescapes the single escape sequence starting at 'in' into 'out' and returns +// how many characters were consumed from 'in' and emitted into 'out'. +// If a valid escape sequence does not appear as a prefix of 'in', (-1, -1) to signal the error. +func unescapeToUTF8(in, out []byte) (inLen int, outLen int) { + if len(in) < 2 || in[0] != '\\' { + // Invalid escape due to insufficient characters for any escape or no initial backslash + return -1, -1 + } + + // https://tools.ietf.org/html/rfc7159#section-7 + switch e := in[1]; e { + case '"', '\\', '/', 'b', 'f', 'n', 'r', 't': + // Valid basic 2-character escapes (use lookup table) + out[0] = backslashCharEscapeTable[e] + return 2, 1 + case 'u': + // Unicode escape + if r, inLen := decodeUnicodeEscape(in); inLen == -1 { + // Invalid Unicode escape + return -1, -1 + } else { + // Valid Unicode escape; re-encode as UTF8 + outLen := utf8.EncodeRune(out, r) + return inLen, outLen + } + } + + return -1, -1 +} + +// unescape unescapes the string contained in 'in' and returns it as a slice. +// If 'in' contains no escaped characters: +// Returns 'in'. +// Else, if 'out' is of sufficient capacity (guaranteed if cap(out) >= len(in)): +// 'out' is used to build the unescaped string and is returned with no extra allocation +// Else: +// A new slice is allocated and returned. +func Unescape(in, out []byte) ([]byte, error) { + firstBackslash := bytes.IndexByte(in, '\\') + if firstBackslash == -1 { + return in, nil + } + + // Get a buffer of sufficient size (allocate if needed) + if cap(out) < len(in) { + out = make([]byte, len(in)) + } else { + out = out[0:len(in)] + } + + // Copy the first sequence of unescaped bytes to the output and obtain a buffer pointer (subslice) + copy(out, in[:firstBackslash]) + in = in[firstBackslash:] + buf := out[firstBackslash:] + + for len(in) > 0 { + // Unescape the next escaped character + inLen, bufLen := unescapeToUTF8(in, buf) + if inLen == -1 { + return nil, MalformedStringEscapeError + } + + in = in[inLen:] + buf = buf[bufLen:] + + // Copy everything up until the next backslash + nextBackslash := bytes.IndexByte(in, '\\') + if nextBackslash == -1 { + copy(buf, in) + buf = buf[len(in):] + break + } else { + copy(buf, in[:nextBackslash]) + buf = buf[nextBackslash:] + in = in[nextBackslash:] + } + } + + // Trim the out buffer to the amount that was actually emitted + return out[:len(out)-len(buf)], nil +} diff --git a/vendor/github.com/buger/jsonparser/fuzz.go b/vendor/github.com/buger/jsonparser/fuzz.go new file mode 100644 index 0000000000..854bd11b2c --- /dev/null +++ b/vendor/github.com/buger/jsonparser/fuzz.go @@ -0,0 +1,117 @@ +package jsonparser + +func FuzzParseString(data []byte) int { + r, err := ParseString(data) + if err != nil || r == "" { + return 0 + } + return 1 +} + +func FuzzEachKey(data []byte) int { + paths := [][]string{ + {"name"}, + {"order"}, + {"nested", "a"}, + {"nested", "b"}, + {"nested2", "a"}, + {"nested", "nested3", "b"}, + {"arr", "[1]", "b"}, + {"arrInt", "[3]"}, + {"arrInt", "[5]"}, + {"nested"}, + {"arr", "["}, + {"a\n", "b\n"}, + } + EachKey(data, func(idx int, value []byte, vt ValueType, err error) {}, paths...) + return 1 +} + +func FuzzDelete(data []byte) int { + Delete(data, "test") + return 1 +} + +func FuzzSet(data []byte) int { + _, err := Set(data, []byte(`"new value"`), "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzObjectEach(data []byte) int { + _ = ObjectEach(data, func(key, value []byte, valueType ValueType, off int) error { + return nil + }) + return 1 +} + +func FuzzParseFloat(data []byte) int { + _, err := ParseFloat(data) + if err != nil { + return 0 + } + return 1 +} + +func FuzzParseInt(data []byte) int { + _, err := ParseInt(data) + if err != nil { + return 0 + } + return 1 +} + +func FuzzParseBool(data []byte) int { + _, err := ParseBoolean(data) + if err != nil { + return 0 + } + return 1 +} + +func FuzzTokenStart(data []byte) int { + _ = tokenStart(data) + return 1 +} + +func FuzzGetString(data []byte) int { + _, err := GetString(data, "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzGetFloat(data []byte) int { + _, err := GetFloat(data, "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzGetInt(data []byte) int { + _, err := GetInt(data, "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzGetBoolean(data []byte) int { + _, err := GetBoolean(data, "test") + if err != nil { + return 0 + } + return 1 +} + +func FuzzGetUnsafeString(data []byte) int { + _, err := GetUnsafeString(data, "test") + if err != nil { + return 0 + } + return 1 +} diff --git a/vendor/github.com/buger/jsonparser/oss-fuzz-build.sh b/vendor/github.com/buger/jsonparser/oss-fuzz-build.sh new file mode 100644 index 0000000000..c573b0e2d1 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/oss-fuzz-build.sh @@ -0,0 +1,47 @@ +#!/bin/bash -eu + +git clone https://github.com/dvyukov/go-fuzz-corpus +zip corpus.zip go-fuzz-corpus/json/corpus/* + +cp corpus.zip $OUT/fuzzparsestring_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzParseString fuzzparsestring + +cp corpus.zip $OUT/fuzzeachkey_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzEachKey fuzzeachkey + +cp corpus.zip $OUT/fuzzdelete_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzDelete fuzzdelete + +cp corpus.zip $OUT/fuzzset_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzSet fuzzset + +cp corpus.zip $OUT/fuzzobjecteach_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzObjectEach fuzzobjecteach + +cp corpus.zip $OUT/fuzzparsefloat_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzParseFloat fuzzparsefloat + +cp corpus.zip $OUT/fuzzparseint_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzParseInt fuzzparseint + +cp corpus.zip $OUT/fuzzparsebool_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzParseBool fuzzparsebool + +cp corpus.zip $OUT/fuzztokenstart_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzTokenStart fuzztokenstart + +cp corpus.zip $OUT/fuzzgetstring_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetString fuzzgetstring + +cp corpus.zip $OUT/fuzzgetfloat_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetFloat fuzzgetfloat + +cp corpus.zip $OUT/fuzzgetint_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetInt fuzzgetint + +cp corpus.zip $OUT/fuzzgetboolean_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetBoolean fuzzgetboolean + +cp corpus.zip $OUT/fuzzgetunsafestring_seed_corpus.zip +compile_go_fuzzer github.com/buger/jsonparser FuzzGetUnsafeString fuzzgetunsafestring + diff --git a/vendor/github.com/buger/jsonparser/parser.go b/vendor/github.com/buger/jsonparser/parser.go new file mode 100644 index 0000000000..14b80bc483 --- /dev/null +++ b/vendor/github.com/buger/jsonparser/parser.go @@ -0,0 +1,1283 @@ +package jsonparser + +import ( + "bytes" + "errors" + "fmt" + "strconv" +) + +// Errors +var ( + KeyPathNotFoundError = errors.New("Key path not found") + UnknownValueTypeError = errors.New("Unknown value type") + MalformedJsonError = errors.New("Malformed JSON error") + MalformedStringError = errors.New("Value is string, but can't find closing '\"' symbol") + MalformedArrayError = errors.New("Value is array, but can't find closing ']' symbol") + MalformedObjectError = errors.New("Value looks like object, but can't find closing '}' symbol") + MalformedValueError = errors.New("Value looks like Number/Boolean/None, but can't find its end: ',' or '}' symbol") + OverflowIntegerError = errors.New("Value is number, but overflowed while parsing") + MalformedStringEscapeError = errors.New("Encountered an invalid escape sequence in a string") +) + +// How much stack space to allocate for unescaping JSON strings; if a string longer +// than this needs to be escaped, it will result in a heap allocation +const unescapeStackBufSize = 64 + +func tokenEnd(data []byte) int { + for i, c := range data { + switch c { + case ' ', '\n', '\r', '\t', ',', '}', ']': + return i + } + } + + return len(data) +} + +func findTokenStart(data []byte, token byte) int { + for i := len(data) - 1; i >= 0; i-- { + switch data[i] { + case token: + return i + case '[', '{': + return 0 + } + } + + return 0 +} + +func findKeyStart(data []byte, key string) (int, error) { + i := 0 + ln := len(data) + if ln > 0 && (data[0] == '{' || data[0] == '[') { + i = 1 + } + var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings + + if ku, err := Unescape(StringToBytes(key), stackbuf[:]); err == nil { + key = bytesToString(&ku) + } + + for i < ln { + switch data[i] { + case '"': + i++ + keyBegin := i + + strEnd, keyEscaped := stringEnd(data[i:]) + if strEnd == -1 { + break + } + i += strEnd + keyEnd := i - 1 + + valueOffset := nextToken(data[i:]) + if valueOffset == -1 { + break + } + + i += valueOffset + + // if string is a key, and key level match + k := data[keyBegin:keyEnd] + // for unescape: if there are no escape sequences, this is cheap; if there are, it is a + // bit more expensive, but causes no allocations unless len(key) > unescapeStackBufSize + if keyEscaped { + if ku, err := Unescape(k, stackbuf[:]); err != nil { + break + } else { + k = ku + } + } + + if data[i] == ':' && len(key) == len(k) && bytesToString(&k) == key { + return keyBegin - 1, nil + } + + case '[': + end := blockEnd(data[i:], data[i], ']') + if end != -1 { + i = i + end + } + case '{': + end := blockEnd(data[i:], data[i], '}') + if end != -1 { + i = i + end + } + } + i++ + } + + return -1, KeyPathNotFoundError +} + +func tokenStart(data []byte) int { + for i := len(data) - 1; i >= 0; i-- { + switch data[i] { + case '\n', '\r', '\t', ',', '{', '[': + return i + } + } + + return 0 +} + +// Find position of next character which is not whitespace +func nextToken(data []byte) int { + for i, c := range data { + switch c { + case ' ', '\n', '\r', '\t': + continue + default: + return i + } + } + + return -1 +} + +// Find position of last character which is not whitespace +func lastToken(data []byte) int { + for i := len(data) - 1; i >= 0; i-- { + switch data[i] { + case ' ', '\n', '\r', '\t': + continue + default: + return i + } + } + + return -1 +} + +// Tries to find the end of string +// Support if string contains escaped quote symbols. +func stringEnd(data []byte) (int, bool) { + escaped := false + for i, c := range data { + if c == '"' { + if !escaped { + return i + 1, false + } else { + j := i - 1 + for { + if j < 0 || data[j] != '\\' { + return i + 1, true // even number of backslashes + } + j-- + if j < 0 || data[j] != '\\' { + break // odd number of backslashes + } + j-- + + } + } + } else if c == '\\' { + escaped = true + } + } + + return -1, escaped +} + +// Find end of the data structure, array or object. +// For array openSym and closeSym will be '[' and ']', for object '{' and '}' +func blockEnd(data []byte, openSym byte, closeSym byte) int { + level := 0 + i := 0 + ln := len(data) + + for i < ln { + switch data[i] { + case '"': // If inside string, skip it + se, _ := stringEnd(data[i+1:]) + if se == -1 { + return -1 + } + i += se + case openSym: // If open symbol, increase level + level++ + case closeSym: // If close symbol, increase level + level-- + + // If we have returned to the original level, we're done + if level == 0 { + return i + 1 + } + } + i++ + } + + return -1 +} + +func searchKeys(data []byte, keys ...string) int { + keyLevel := 0 + level := 0 + i := 0 + ln := len(data) + lk := len(keys) + lastMatched := true + + if lk == 0 { + return 0 + } + + var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings + + for i < ln { + switch data[i] { + case '"': + i++ + keyBegin := i + + strEnd, keyEscaped := stringEnd(data[i:]) + if strEnd == -1 { + return -1 + } + i += strEnd + keyEnd := i - 1 + + valueOffset := nextToken(data[i:]) + if valueOffset == -1 { + return -1 + } + + i += valueOffset + + // if string is a key + if data[i] == ':' { + if level < 1 { + return -1 + } + + key := data[keyBegin:keyEnd] + + // for unescape: if there are no escape sequences, this is cheap; if there are, it is a + // bit more expensive, but causes no allocations unless len(key) > unescapeStackBufSize + var keyUnesc []byte + if !keyEscaped { + keyUnesc = key + } else if ku, err := Unescape(key, stackbuf[:]); err != nil { + return -1 + } else { + keyUnesc = ku + } + + if level <= len(keys) { + if equalStr(&keyUnesc, keys[level-1]) { + lastMatched = true + + // if key level match + if keyLevel == level-1 { + keyLevel++ + // If we found all keys in path + if keyLevel == lk { + return i + 1 + } + } + } else { + lastMatched = false + } + } else { + return -1 + } + } else { + i-- + } + case '{': + + // in case parent key is matched then only we will increase the level otherwise can directly + // can move to the end of this block + if !lastMatched { + end := blockEnd(data[i:], '{', '}') + if end == -1 { + return -1 + } + i += end - 1 + } else { + level++ + } + case '}': + level-- + if level == keyLevel { + keyLevel-- + } + case '[': + // If we want to get array element by index + if keyLevel == level && keys[level][0] == '[' { + var keyLen = len(keys[level]) + if keyLen < 3 || keys[level][0] != '[' || keys[level][keyLen-1] != ']' { + return -1 + } + aIdx, err := strconv.Atoi(keys[level][1 : keyLen-1]) + if err != nil { + return -1 + } + var curIdx int + var valueFound []byte + var valueOffset int + var curI = i + ArrayEach(data[i:], func(value []byte, dataType ValueType, offset int, err error) { + if curIdx == aIdx { + valueFound = value + valueOffset = offset + if dataType == String { + valueOffset = valueOffset - 2 + valueFound = data[curI+valueOffset : curI+valueOffset+len(value)+2] + } + } + curIdx += 1 + }) + + if valueFound == nil { + return -1 + } else { + subIndex := searchKeys(valueFound, keys[level+1:]...) + if subIndex < 0 { + return -1 + } + return i + valueOffset + subIndex + } + } else { + // Do not search for keys inside arrays + if arraySkip := blockEnd(data[i:], '[', ']'); arraySkip == -1 { + return -1 + } else { + i += arraySkip - 1 + } + } + case ':': // If encountered, JSON data is malformed + return -1 + } + + i++ + } + + return -1 +} + +func sameTree(p1, p2 []string) bool { + minLen := len(p1) + if len(p2) < minLen { + minLen = len(p2) + } + + for pi_1, p_1 := range p1[:minLen] { + if p2[pi_1] != p_1 { + return false + } + } + + return true +} + +func EachKey(data []byte, cb func(int, []byte, ValueType, error), paths ...[]string) int { + var x struct{} + pathFlags := make([]bool, len(paths)) + var level, pathsMatched, i int + ln := len(data) + + var maxPath int + for _, p := range paths { + if len(p) > maxPath { + maxPath = len(p) + } + } + + pathsBuf := make([]string, maxPath) + + for i < ln { + switch data[i] { + case '"': + i++ + keyBegin := i + + strEnd, keyEscaped := stringEnd(data[i:]) + if strEnd == -1 { + return -1 + } + i += strEnd + + keyEnd := i - 1 + + valueOffset := nextToken(data[i:]) + if valueOffset == -1 { + return -1 + } + + i += valueOffset + + // if string is a key, and key level match + if data[i] == ':' { + match := -1 + key := data[keyBegin:keyEnd] + + // for unescape: if there are no escape sequences, this is cheap; if there are, it is a + // bit more expensive, but causes no allocations unless len(key) > unescapeStackBufSize + var keyUnesc []byte + if !keyEscaped { + keyUnesc = key + } else { + var stackbuf [unescapeStackBufSize]byte + if ku, err := Unescape(key, stackbuf[:]); err != nil { + return -1 + } else { + keyUnesc = ku + } + } + + if maxPath >= level { + if level < 1 { + cb(-1, nil, Unknown, MalformedJsonError) + return -1 + } + + pathsBuf[level-1] = bytesToString(&keyUnesc) + for pi, p := range paths { + if len(p) != level || pathFlags[pi] || !equalStr(&keyUnesc, p[level-1]) || !sameTree(p, pathsBuf[:level]) { + continue + } + + match = pi + + pathsMatched++ + pathFlags[pi] = true + + v, dt, _, e := Get(data[i+1:]) + cb(pi, v, dt, e) + + if pathsMatched == len(paths) { + break + } + } + if pathsMatched == len(paths) { + return i + } + } + + if match == -1 { + tokenOffset := nextToken(data[i+1:]) + i += tokenOffset + + if data[i] == '{' { + blockSkip := blockEnd(data[i:], '{', '}') + i += blockSkip + 1 + } + } + + if i < ln { + switch data[i] { + case '{', '}', '[', '"': + i-- + } + } + } else { + i-- + } + case '{': + level++ + case '}': + level-- + case '[': + var ok bool + arrIdxFlags := make(map[int]struct{}) + pIdxFlags := make([]bool, len(paths)) + + if level < 0 { + cb(-1, nil, Unknown, MalformedJsonError) + return -1 + } + + for pi, p := range paths { + if len(p) < level+1 || pathFlags[pi] || p[level][0] != '[' || !sameTree(p, pathsBuf[:level]) { + continue + } + if len(p[level]) >= 2 { + aIdx, _ := strconv.Atoi(p[level][1 : len(p[level])-1]) + arrIdxFlags[aIdx] = x + pIdxFlags[pi] = true + } + } + + if len(arrIdxFlags) > 0 { + level++ + + var curIdx int + arrOff, _ := ArrayEach(data[i:], func(value []byte, dataType ValueType, offset int, err error) { + if _, ok = arrIdxFlags[curIdx]; ok { + for pi, p := range paths { + if pIdxFlags[pi] { + aIdx, _ := strconv.Atoi(p[level-1][1 : len(p[level-1])-1]) + + if curIdx == aIdx { + of := searchKeys(value, p[level:]...) + + pathsMatched++ + pathFlags[pi] = true + + if of != -1 { + v, dt, _, e := Get(value[of:]) + cb(pi, v, dt, e) + } + } + } + } + } + + curIdx += 1 + }) + + if pathsMatched == len(paths) { + return i + } + + i += arrOff - 1 + } else { + // Do not search for keys inside arrays + if arraySkip := blockEnd(data[i:], '[', ']'); arraySkip == -1 { + return -1 + } else { + i += arraySkip - 1 + } + } + case ']': + level-- + } + + i++ + } + + return -1 +} + +// Data types available in valid JSON data. +type ValueType int + +const ( + NotExist = ValueType(iota) + String + Number + Object + Array + Boolean + Null + Unknown +) + +func (vt ValueType) String() string { + switch vt { + case NotExist: + return "non-existent" + case String: + return "string" + case Number: + return "number" + case Object: + return "object" + case Array: + return "array" + case Boolean: + return "boolean" + case Null: + return "null" + default: + return "unknown" + } +} + +var ( + trueLiteral = []byte("true") + falseLiteral = []byte("false") + nullLiteral = []byte("null") +) + +func createInsertComponent(keys []string, setValue []byte, comma, object bool) []byte { + isIndex := string(keys[0][0]) == "[" + offset := 0 + lk := calcAllocateSpace(keys, setValue, comma, object) + buffer := make([]byte, lk, lk) + if comma { + offset += WriteToBuffer(buffer[offset:], ",") + } + if isIndex && !comma { + offset += WriteToBuffer(buffer[offset:], "[") + } else { + if object { + offset += WriteToBuffer(buffer[offset:], "{") + } + if !isIndex { + offset += WriteToBuffer(buffer[offset:], "\"") + offset += WriteToBuffer(buffer[offset:], keys[0]) + offset += WriteToBuffer(buffer[offset:], "\":") + } + } + + for i := 1; i < len(keys); i++ { + if string(keys[i][0]) == "[" { + offset += WriteToBuffer(buffer[offset:], "[") + } else { + offset += WriteToBuffer(buffer[offset:], "{\"") + offset += WriteToBuffer(buffer[offset:], keys[i]) + offset += WriteToBuffer(buffer[offset:], "\":") + } + } + offset += WriteToBuffer(buffer[offset:], string(setValue)) + for i := len(keys) - 1; i > 0; i-- { + if string(keys[i][0]) == "[" { + offset += WriteToBuffer(buffer[offset:], "]") + } else { + offset += WriteToBuffer(buffer[offset:], "}") + } + } + if isIndex && !comma { + offset += WriteToBuffer(buffer[offset:], "]") + } + if object && !isIndex { + offset += WriteToBuffer(buffer[offset:], "}") + } + return buffer +} + +func calcAllocateSpace(keys []string, setValue []byte, comma, object bool) int { + isIndex := string(keys[0][0]) == "[" + lk := 0 + if comma { + // , + lk += 1 + } + if isIndex && !comma { + // [] + lk += 2 + } else { + if object { + // { + lk += 1 + } + if !isIndex { + // "keys[0]" + lk += len(keys[0]) + 3 + } + } + + + lk += len(setValue) + for i := 1; i < len(keys); i++ { + if string(keys[i][0]) == "[" { + // [] + lk += 2 + } else { + // {"keys[i]":setValue} + lk += len(keys[i]) + 5 + } + } + + if object && !isIndex { + // } + lk += 1 + } + + return lk +} + +func WriteToBuffer(buffer []byte, str string) int { + copy(buffer, str) + return len(str) +} + +/* + +Del - Receives existing data structure, path to delete. + +Returns: +`data` - return modified data + +*/ +func Delete(data []byte, keys ...string) []byte { + lk := len(keys) + if lk == 0 { + return data[:0] + } + + array := false + if len(keys[lk-1]) > 0 && string(keys[lk-1][0]) == "[" { + array = true + } + + var startOffset, keyOffset int + endOffset := len(data) + var err error + if !array { + if len(keys) > 1 { + _, _, startOffset, endOffset, err = internalGet(data, keys[:lk-1]...) + if err == KeyPathNotFoundError { + // problem parsing the data + return data + } + } + + keyOffset, err = findKeyStart(data[startOffset:endOffset], keys[lk-1]) + if err == KeyPathNotFoundError { + // problem parsing the data + return data + } + keyOffset += startOffset + _, _, _, subEndOffset, _ := internalGet(data[startOffset:endOffset], keys[lk-1]) + endOffset = startOffset + subEndOffset + tokEnd := tokenEnd(data[endOffset:]) + tokStart := findTokenStart(data[:keyOffset], ","[0]) + + if data[endOffset+tokEnd] == ","[0] { + endOffset += tokEnd + 1 + } else if data[endOffset+tokEnd] == " "[0] && len(data) > endOffset+tokEnd+1 && data[endOffset+tokEnd+1] == ","[0] { + endOffset += tokEnd + 2 + } else if data[endOffset+tokEnd] == "}"[0] && data[tokStart] == ","[0] { + keyOffset = tokStart + } + } else { + _, _, keyOffset, endOffset, err = internalGet(data, keys...) + if err == KeyPathNotFoundError { + // problem parsing the data + return data + } + + tokEnd := tokenEnd(data[endOffset:]) + tokStart := findTokenStart(data[:keyOffset], ","[0]) + + if data[endOffset+tokEnd] == ","[0] { + endOffset += tokEnd + 1 + } else if data[endOffset+tokEnd] == "]"[0] && data[tokStart] == ","[0] { + keyOffset = tokStart + } + } + + // We need to remove remaining trailing comma if we delete las element in the object + prevTok := lastToken(data[:keyOffset]) + remainedValue := data[endOffset:] + + var newOffset int + if nextToken(remainedValue) > -1 && remainedValue[nextToken(remainedValue)] == '}' && data[prevTok] == ',' { + newOffset = prevTok + } else { + newOffset = prevTok + 1 + } + + // We have to make a copy here if we don't want to mangle the original data, because byte slices are + // accessed by reference and not by value + dataCopy := make([]byte, len(data)) + copy(dataCopy, data) + data = append(dataCopy[:newOffset], dataCopy[endOffset:]...) + + return data +} + +/* + +Set - Receives existing data structure, path to set, and data to set at that key. + +Returns: +`value` - modified byte array +`err` - On any parsing error + +*/ +func Set(data []byte, setValue []byte, keys ...string) (value []byte, err error) { + // ensure keys are set + if len(keys) == 0 { + return nil, KeyPathNotFoundError + } + + _, _, startOffset, endOffset, err := internalGet(data, keys...) + if err != nil { + if err != KeyPathNotFoundError { + // problem parsing the data + return nil, err + } + // full path doesnt exist + // does any subpath exist? + var depth int + for i := range keys { + _, _, start, end, sErr := internalGet(data, keys[:i+1]...) + if sErr != nil { + break + } else { + endOffset = end + startOffset = start + depth++ + } + } + comma := true + object := false + if endOffset == -1 { + firstToken := nextToken(data) + // We can't set a top-level key if data isn't an object + if firstToken < 0 || data[firstToken] != '{' { + return nil, KeyPathNotFoundError + } + // Don't need a comma if the input is an empty object + secondToken := firstToken + 1 + nextToken(data[firstToken+1:]) + if data[secondToken] == '}' { + comma = false + } + // Set the top level key at the end (accounting for any trailing whitespace) + // This assumes last token is valid like '}', could check and return error + endOffset = lastToken(data) + } + depthOffset := endOffset + if depth != 0 { + // if subpath is a non-empty object, add to it + // or if subpath is a non-empty array, add to it + if (data[startOffset] == '{' && data[startOffset+1+nextToken(data[startOffset+1:])] != '}') || + (data[startOffset] == '[' && data[startOffset+1+nextToken(data[startOffset+1:])] == '{') && keys[depth:][0][0] == 91 { + depthOffset-- + startOffset = depthOffset + // otherwise, over-write it with a new object + } else { + comma = false + object = true + } + } else { + startOffset = depthOffset + } + value = append(data[:startOffset], append(createInsertComponent(keys[depth:], setValue, comma, object), data[depthOffset:]...)...) + } else { + // path currently exists + startComponent := data[:startOffset] + endComponent := data[endOffset:] + + value = make([]byte, len(startComponent)+len(endComponent)+len(setValue)) + newEndOffset := startOffset + len(setValue) + copy(value[0:startOffset], startComponent) + copy(value[startOffset:newEndOffset], setValue) + copy(value[newEndOffset:], endComponent) + } + return value, nil +} + +func getType(data []byte, offset int) ([]byte, ValueType, int, error) { + var dataType ValueType + endOffset := offset + + // if string value + if data[offset] == '"' { + dataType = String + if idx, _ := stringEnd(data[offset+1:]); idx != -1 { + endOffset += idx + 1 + } else { + return nil, dataType, offset, MalformedStringError + } + } else if data[offset] == '[' { // if array value + dataType = Array + // break label, for stopping nested loops + endOffset = blockEnd(data[offset:], '[', ']') + + if endOffset == -1 { + return nil, dataType, offset, MalformedArrayError + } + + endOffset += offset + } else if data[offset] == '{' { // if object value + dataType = Object + // break label, for stopping nested loops + endOffset = blockEnd(data[offset:], '{', '}') + + if endOffset == -1 { + return nil, dataType, offset, MalformedObjectError + } + + endOffset += offset + } else { + // Number, Boolean or None + end := tokenEnd(data[endOffset:]) + + if end == -1 { + return nil, dataType, offset, MalformedValueError + } + + value := data[offset : endOffset+end] + + switch data[offset] { + case 't', 'f': // true or false + if bytes.Equal(value, trueLiteral) || bytes.Equal(value, falseLiteral) { + dataType = Boolean + } else { + return nil, Unknown, offset, UnknownValueTypeError + } + case 'u', 'n': // undefined or null + if bytes.Equal(value, nullLiteral) { + dataType = Null + } else { + return nil, Unknown, offset, UnknownValueTypeError + } + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': + dataType = Number + default: + return nil, Unknown, offset, UnknownValueTypeError + } + + endOffset += end + } + return data[offset:endOffset], dataType, endOffset, nil +} + +/* +Get - Receives data structure, and key path to extract value from. + +Returns: +`value` - Pointer to original data structure containing key value, or just empty slice if nothing found or error +`dataType` - Can be: `NotExist`, `String`, `Number`, `Object`, `Array`, `Boolean` or `Null` +`offset` - Offset from provided data structure where key value ends. Used mostly internally, for example for `ArrayEach` helper. +`err` - If key not found or any other parsing issue it should return error. If key not found it also sets `dataType` to `NotExist` + +Accept multiple keys to specify path to JSON value (in case of quering nested structures). +If no keys provided it will try to extract closest JSON value (simple ones or object/array), useful for reading streams or arrays, see `ArrayEach` implementation. +*/ +func Get(data []byte, keys ...string) (value []byte, dataType ValueType, offset int, err error) { + a, b, _, d, e := internalGet(data, keys...) + return a, b, d, e +} + +func internalGet(data []byte, keys ...string) (value []byte, dataType ValueType, offset, endOffset int, err error) { + if len(keys) > 0 { + if offset = searchKeys(data, keys...); offset == -1 { + return nil, NotExist, -1, -1, KeyPathNotFoundError + } + } + + // Go to closest value + nO := nextToken(data[offset:]) + if nO == -1 { + return nil, NotExist, offset, -1, MalformedJsonError + } + + offset += nO + value, dataType, endOffset, err = getType(data, offset) + if err != nil { + return value, dataType, offset, endOffset, err + } + + // Strip quotes from string values + if dataType == String { + value = value[1 : len(value)-1] + } + + return value[:len(value):len(value)], dataType, offset, endOffset, nil +} + +// ArrayEach is used when iterating arrays, accepts a callback function with the same return arguments as `Get`. +func ArrayEach(data []byte, cb func(value []byte, dataType ValueType, offset int, err error), keys ...string) (offset int, err error) { + if len(data) == 0 { + return -1, MalformedObjectError + } + + nT := nextToken(data) + if nT == -1 { + return -1, MalformedJsonError + } + + offset = nT + 1 + + if len(keys) > 0 { + if offset = searchKeys(data, keys...); offset == -1 { + return offset, KeyPathNotFoundError + } + + // Go to closest value + nO := nextToken(data[offset:]) + if nO == -1 { + return offset, MalformedJsonError + } + + offset += nO + + if data[offset] != '[' { + return offset, MalformedArrayError + } + + offset++ + } + + nO := nextToken(data[offset:]) + if nO == -1 { + return offset, MalformedJsonError + } + + offset += nO + + if data[offset] == ']' { + return offset, nil + } + + for true { + v, t, o, e := Get(data[offset:]) + + if e != nil { + return offset, e + } + + if o == 0 { + break + } + + if t != NotExist { + cb(v, t, offset+o-len(v), e) + } + + if e != nil { + break + } + + offset += o + + skipToToken := nextToken(data[offset:]) + if skipToToken == -1 { + return offset, MalformedArrayError + } + offset += skipToToken + + if data[offset] == ']' { + break + } + + if data[offset] != ',' { + return offset, MalformedArrayError + } + + offset++ + } + + return offset, nil +} + +// ObjectEach iterates over the key-value pairs of a JSON object, invoking a given callback for each such entry +func ObjectEach(data []byte, callback func(key []byte, value []byte, dataType ValueType, offset int) error, keys ...string) (err error) { + offset := 0 + + // Descend to the desired key, if requested + if len(keys) > 0 { + if off := searchKeys(data, keys...); off == -1 { + return KeyPathNotFoundError + } else { + offset = off + } + } + + // Validate and skip past opening brace + if off := nextToken(data[offset:]); off == -1 { + return MalformedObjectError + } else if offset += off; data[offset] != '{' { + return MalformedObjectError + } else { + offset++ + } + + // Skip to the first token inside the object, or stop if we find the ending brace + if off := nextToken(data[offset:]); off == -1 { + return MalformedJsonError + } else if offset += off; data[offset] == '}' { + return nil + } + + // Loop pre-condition: data[offset] points to what should be either the next entry's key, or the closing brace (if it's anything else, the JSON is malformed) + for offset < len(data) { + // Step 1: find the next key + var key []byte + + // Check what the the next token is: start of string, end of object, or something else (error) + switch data[offset] { + case '"': + offset++ // accept as string and skip opening quote + case '}': + return nil // we found the end of the object; stop and return success + default: + return MalformedObjectError + } + + // Find the end of the key string + var keyEscaped bool + if off, esc := stringEnd(data[offset:]); off == -1 { + return MalformedJsonError + } else { + key, keyEscaped = data[offset:offset+off-1], esc + offset += off + } + + // Unescape the string if needed + if keyEscaped { + var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings + if keyUnescaped, err := Unescape(key, stackbuf[:]); err != nil { + return MalformedStringEscapeError + } else { + key = keyUnescaped + } + } + + // Step 2: skip the colon + if off := nextToken(data[offset:]); off == -1 { + return MalformedJsonError + } else if offset += off; data[offset] != ':' { + return MalformedJsonError + } else { + offset++ + } + + // Step 3: find the associated value, then invoke the callback + if value, valueType, off, err := Get(data[offset:]); err != nil { + return err + } else if err := callback(key, value, valueType, offset+off); err != nil { // Invoke the callback here! + return err + } else { + offset += off + } + + // Step 4: skip over the next comma to the following token, or stop if we hit the ending brace + if off := nextToken(data[offset:]); off == -1 { + return MalformedArrayError + } else { + offset += off + switch data[offset] { + case '}': + return nil // Stop if we hit the close brace + case ',': + offset++ // Ignore the comma + default: + return MalformedObjectError + } + } + + // Skip to the next token after the comma + if off := nextToken(data[offset:]); off == -1 { + return MalformedArrayError + } else { + offset += off + } + } + + return MalformedObjectError // we shouldn't get here; it's expected that we will return via finding the ending brace +} + +// GetUnsafeString returns the value retrieved by `Get`, use creates string without memory allocation by mapping string to slice memory. It does not handle escape symbols. +func GetUnsafeString(data []byte, keys ...string) (val string, err error) { + v, _, _, e := Get(data, keys...) + + if e != nil { + return "", e + } + + return bytesToString(&v), nil +} + +// GetString returns the value retrieved by `Get`, cast to a string if possible, trying to properly handle escape and utf8 symbols +// If key data type do not match, it will return an error. +func GetString(data []byte, keys ...string) (val string, err error) { + v, t, _, e := Get(data, keys...) + + if e != nil { + return "", e + } + + if t != String { + return "", fmt.Errorf("Value is not a string: %s", string(v)) + } + + // If no escapes return raw content + if bytes.IndexByte(v, '\\') == -1 { + return string(v), nil + } + + return ParseString(v) +} + +// GetFloat returns the value retrieved by `Get`, cast to a float64 if possible. +// The offset is the same as in `Get`. +// If key data type do not match, it will return an error. +func GetFloat(data []byte, keys ...string) (val float64, err error) { + v, t, _, e := Get(data, keys...) + + if e != nil { + return 0, e + } + + if t != Number { + return 0, fmt.Errorf("Value is not a number: %s", string(v)) + } + + return ParseFloat(v) +} + +// GetInt returns the value retrieved by `Get`, cast to a int64 if possible. +// If key data type do not match, it will return an error. +func GetInt(data []byte, keys ...string) (val int64, err error) { + v, t, _, e := Get(data, keys...) + + if e != nil { + return 0, e + } + + if t != Number { + return 0, fmt.Errorf("Value is not a number: %s", string(v)) + } + + return ParseInt(v) +} + +// GetBoolean returns the value retrieved by `Get`, cast to a bool if possible. +// The offset is the same as in `Get`. +// If key data type do not match, it will return error. +func GetBoolean(data []byte, keys ...string) (val bool, err error) { + v, t, _, e := Get(data, keys...) + + if e != nil { + return false, e + } + + if t != Boolean { + return false, fmt.Errorf("Value is not a boolean: %s", string(v)) + } + + return ParseBoolean(v) +} + +// ParseBoolean parses a Boolean ValueType into a Go bool (not particularly useful, but here for completeness) +func ParseBoolean(b []byte) (bool, error) { + switch { + case bytes.Equal(b, trueLiteral): + return true, nil + case bytes.Equal(b, falseLiteral): + return false, nil + default: + return false, MalformedValueError + } +} + +// ParseString parses a String ValueType into a Go string (the main parsing work is unescaping the JSON string) +func ParseString(b []byte) (string, error) { + var stackbuf [unescapeStackBufSize]byte // stack-allocated array for allocation-free unescaping of small strings + if bU, err := Unescape(b, stackbuf[:]); err != nil { + return "", MalformedValueError + } else { + return string(bU), nil + } +} + +// ParseNumber parses a Number ValueType into a Go float64 +func ParseFloat(b []byte) (float64, error) { + if v, err := parseFloat(&b); err != nil { + return 0, MalformedValueError + } else { + return v, nil + } +} + +// ParseInt parses a Number ValueType into a Go int64 +func ParseInt(b []byte) (int64, error) { + if v, ok, overflow := parseInt(b); !ok { + if overflow { + return 0, OverflowIntegerError + } + return 0, MalformedValueError + } else { + return v, nil + } +} diff --git a/vendor/github.com/invopop/jsonschema/.gitignore b/vendor/github.com/invopop/jsonschema/.gitignore new file mode 100644 index 0000000000..8ef0e14fc7 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/.gitignore @@ -0,0 +1,2 @@ +vendor/ +.idea/ diff --git a/vendor/github.com/invopop/jsonschema/.golangci.yml b/vendor/github.com/invopop/jsonschema/.golangci.yml new file mode 100644 index 0000000000..b89b2e124d --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/.golangci.yml @@ -0,0 +1,69 @@ +run: + tests: true + max-same-issues: 50 + +output: + print-issued-lines: false + +linters: + enable: + - gocyclo + - gocritic + - goconst + - dupl + - unconvert + - goimports + - unused + - govet + - nakedret + - errcheck + - revive + - ineffassign + - goconst + - unparam + - gofmt + +linters-settings: + vet: + check-shadowing: true + use-installed-packages: true + dupl: + threshold: 100 + goconst: + min-len: 8 + min-occurrences: 3 + gocyclo: + min-complexity: 20 + gocritic: + disabled-checks: + - ifElseChain + gofmt: + rewrite-rules: + - pattern: "interface{}" + replacement: "any" + - pattern: "a[b:len(a)]" + replacement: "a[b:]" + +issues: + max-per-linter: 0 + max-same: 0 + exclude-dirs: + - resources + - old + exclude-files: + - cmd/protopkg/main.go + exclude-use-default: false + exclude: + # Captured by errcheck. + - "^(G104|G204):" + # Very commonly not checked. + - 'Error return value of .(.*\.Help|.*\.MarkFlagRequired|(os\.)?std(out|err)\..*|.*Close|.*Flush|os\.Remove(All)?|.*Print(f|ln|)|os\.(Un)?Setenv). is not checked' + # Weird error only seen on Kochiku... + - "internal error: no range for" + - 'exported method `.*\.(MarshalJSON|UnmarshalJSON|URN|Payload|GoString|Close|Provides|Requires|ExcludeFromHash|MarshalText|UnmarshalText|Description|Check|Poll|Severity)` should have comment or be unexported' + - "composite literal uses unkeyed fields" + - 'declaration of "err" shadows declaration' + - "by other packages, and that stutters" + - "Potential file inclusion via variable" + - "at least one file in a package should have a package comment" + - "bad syntax for struct tag pair" diff --git a/vendor/github.com/invopop/jsonschema/COPYING b/vendor/github.com/invopop/jsonschema/COPYING new file mode 100644 index 0000000000..2993ec085d --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/COPYING @@ -0,0 +1,19 @@ +Copyright (C) 2014 Alec Thomas + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +of the Software, and to permit persons to whom the Software is furnished to do +so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/invopop/jsonschema/README.md b/vendor/github.com/invopop/jsonschema/README.md new file mode 100644 index 0000000000..27b362e1dd --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/README.md @@ -0,0 +1,374 @@ +# Go JSON Schema Reflection + +[![Lint](https://github.com/invopop/jsonschema/actions/workflows/lint.yaml/badge.svg)](https://github.com/invopop/jsonschema/actions/workflows/lint.yaml) +[![Test Go](https://github.com/invopop/jsonschema/actions/workflows/test.yaml/badge.svg)](https://github.com/invopop/jsonschema/actions/workflows/test.yaml) +[![Go Report Card](https://goreportcard.com/badge/github.com/invopop/jsonschema)](https://goreportcard.com/report/github.com/invopop/jsonschema) +[![GoDoc](https://godoc.org/github.com/invopop/jsonschema?status.svg)](https://godoc.org/github.com/invopop/jsonschema) +[![codecov](https://codecov.io/gh/invopop/jsonschema/graph/badge.svg?token=JMEB8W8GNZ)](https://codecov.io/gh/invopop/jsonschema) +![Latest Tag](https://img.shields.io/github/v/tag/invopop/jsonschema) + +This package can be used to generate [JSON Schemas](http://json-schema.org/latest/json-schema-validation.html) from Go types through reflection. + +- Supports arbitrarily complex types, including `interface{}`, maps, slices, etc. +- Supports json-schema features such as minLength, maxLength, pattern, format, etc. +- Supports simple string and numeric enums. +- Supports custom property fields via the `jsonschema_extras` struct tag. + +This repository is a fork of the original [jsonschema](https://github.com/alecthomas/jsonschema) by [@alecthomas](https://github.com/alecthomas). At [Invopop](https://invopop.com) we use jsonschema as a cornerstone in our [GOBL library](https://github.com/invopop/gobl), and wanted to be able to continue building and adding features without taking up Alec's time. There have been a few significant changes that probably mean this version is a not compatible with with Alec's: + +- The original was stuck on the draft-04 version of JSON Schema, we've now moved to the latest JSON Schema Draft 2020-12. +- Schema IDs are added automatically from the current Go package's URL in order to be unique, and can be disabled with the `Anonymous` option. +- Support for the `FullyQualifyTypeName` option has been removed. If you have conflicts, you should use multiple schema files with different IDs, set the `DoNotReference` option to true to hide definitions completely, or add your own naming strategy using the `Namer` property. +- Support for `yaml` tags and related options has been dropped for the sake of simplification. There were a [few inconsistencies](https://github.com/invopop/jsonschema/pull/21) around this that have now been fixed. + +## Versions + +This project is still under v0 scheme, as per Go convention, breaking changes are likely. Please pin go modules to version tags or branches, and reach out if you think something can be improved. + +Go version >= 1.18 is required as generics are now being used. + +## Example + +The following Go type: + +```go +type TestUser struct { + ID int `json:"id"` + Name string `json:"name" jsonschema:"title=the name,description=The name of a friend,example=joe,example=lucy,default=alex"` + Friends []int `json:"friends,omitempty" jsonschema_description:"The list of IDs, omitted when empty"` + Tags map[string]interface{} `json:"tags,omitempty" jsonschema_extras:"a=b,foo=bar,foo=bar1"` + BirthDate time.Time `json:"birth_date,omitempty" jsonschema:"oneof_required=date"` + YearOfBirth string `json:"year_of_birth,omitempty" jsonschema:"oneof_required=year"` + Metadata interface{} `json:"metadata,omitempty" jsonschema:"oneof_type=string;array"` + FavColor string `json:"fav_color,omitempty" jsonschema:"enum=red,enum=green,enum=blue"` +} +``` + +Results in following JSON Schema: + +```go +jsonschema.Reflect(&TestUser{}) +``` + +```json +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/invopop/jsonschema_test/test-user", + "$ref": "#/$defs/TestUser", + "$defs": { + "TestUser": { + "oneOf": [ + { + "required": ["birth_date"], + "title": "date" + }, + { + "required": ["year_of_birth"], + "title": "year" + } + ], + "properties": { + "id": { + "type": "integer" + }, + "name": { + "type": "string", + "title": "the name", + "description": "The name of a friend", + "default": "alex", + "examples": ["joe", "lucy"] + }, + "friends": { + "items": { + "type": "integer" + }, + "type": "array", + "description": "The list of IDs, omitted when empty" + }, + "tags": { + "type": "object", + "a": "b", + "foo": ["bar", "bar1"] + }, + "birth_date": { + "type": "string", + "format": "date-time" + }, + "year_of_birth": { + "type": "string" + }, + "metadata": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "array" + } + ] + }, + "fav_color": { + "type": "string", + "enum": ["red", "green", "blue"] + } + }, + "additionalProperties": false, + "type": "object", + "required": ["id", "name"] + } + } +} +``` + +## YAML + +Support for `yaml` tags has now been removed. If you feel very strongly about this, we've opened a discussion to hear your comments: https://github.com/invopop/jsonschema/discussions/28 + +The recommended approach if you need to deal with YAML data is to first convert to JSON. The [invopop/yaml](https://github.com/invopop/yaml) library will make this trivial. + +## Configurable behaviour + +The behaviour of the schema generator can be altered with parameters when a `jsonschema.Reflector` +instance is created. + +### ExpandedStruct + +If set to `true`, makes the top level struct not to reference itself in the definitions. But type passed should be a struct type. + +eg. + +```go +type GrandfatherType struct { + FamilyName string `json:"family_name" jsonschema:"required"` +} + +type SomeBaseType struct { + SomeBaseProperty int `json:"some_base_property"` + // The jsonschema required tag is nonsensical for private and ignored properties. + // Their presence here tests that the fields *will not* be required in the output + // schema, even if they are tagged required. + somePrivateBaseProperty string `json:"i_am_private" jsonschema:"required"` + SomeIgnoredBaseProperty string `json:"-" jsonschema:"required"` + SomeSchemaIgnoredProperty string `jsonschema:"-,required"` + SomeUntaggedBaseProperty bool `jsonschema:"required"` + someUnexportedUntaggedBaseProperty bool + Grandfather GrandfatherType `json:"grand"` +} +``` + +will output: + +```json +{ + "$schema": "http://json-schema.org/draft/2020-12/schema", + "required": ["some_base_property", "grand", "SomeUntaggedBaseProperty"], + "properties": { + "SomeUntaggedBaseProperty": { + "type": "boolean" + }, + "grand": { + "$schema": "http://json-schema.org/draft/2020-12/schema", + "$ref": "#/definitions/GrandfatherType" + }, + "some_base_property": { + "type": "integer" + } + }, + "type": "object", + "$defs": { + "GrandfatherType": { + "required": ["family_name"], + "properties": { + "family_name": { + "type": "string" + } + }, + "additionalProperties": false, + "type": "object" + } + } +} +``` + +### Using Go Comments + +Writing a good schema with descriptions inside tags can become cumbersome and tedious, especially if you already have some Go comments around your types and field definitions. If you'd like to take advantage of these existing comments, you can use the `AddGoComments(base, path string)` method that forms part of the reflector to parse your go files and automatically generate a dictionary of Go import paths, types, and fields, to individual comments. These will then be used automatically as description fields, and can be overridden with a manual definition if needed. + +Take a simplified example of a User struct which for the sake of simplicity we assume is defined inside this package: + +```go +package main + +// User is used as a base to provide tests for comments. +type User struct { + // Unique sequential identifier. + ID int `json:"id" jsonschema:"required"` + // Name of the user + Name string `json:"name"` +} +``` + +To get the comments provided into your JSON schema, use a regular `Reflector` and add the go code using an import module URL and path. Fully qualified go module paths cannot be determined reliably by the `go/parser` library, so we need to introduce this manually: + +```go +r := new(Reflector) +if err := r.AddGoComments("github.com/invopop/jsonschema", "./"); err != nil { + // deal with error +} +s := r.Reflect(&User{}) +// output +``` + +Expect the results to be similar to: + +```json +{ + "$schema": "http://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/User", + "$defs": { + "User": { + "required": ["id"], + "properties": { + "id": { + "type": "integer", + "description": "Unique sequential identifier." + }, + "name": { + "type": "string", + "description": "Name of the user" + } + }, + "additionalProperties": false, + "type": "object", + "description": "User is used as a base to provide tests for comments." + } + } +} +``` + +### Custom Key Naming + +In some situations, the keys actually used to write files are different from Go structs'. + +This is often the case when writing a configuration file to YAML or JSON from a Go struct, or when returning a JSON response for a Web API: APIs typically use snake_case, while Go uses PascalCase. + +You can pass a `func(string) string` function to `Reflector`'s `KeyNamer` option to map Go field names to JSON key names and reflect the aforementioned transformations, without having to specify `json:"..."` on every struct field. + +For example, consider the following struct + +```go +type User struct { + GivenName string + PasswordSalted []byte `json:"salted_password"` +} +``` + +We can transform field names to snake_case in the generated JSON schema: + +```go +r := new(jsonschema.Reflector) +r.KeyNamer = strcase.SnakeCase // from package github.com/stoewer/go-strcase + +r.Reflect(&User{}) +``` + +Will yield + +```diff + { + "$schema": "http://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/User", + "$defs": { + "User": { + "properties": { +- "GivenName": { ++ "given_name": { + "type": "string" + }, + "salted_password": { + "type": "string", + "contentEncoding": "base64" + } + }, + "additionalProperties": false, + "type": "object", +- "required": ["GivenName", "salted_password"] ++ "required": ["given_name", "salted_password"] + } + } + } +``` + +As you can see, if a field name has a `json:""` tag set, the `key` argument to `KeyNamer` will have the value of that tag. + +### Custom Type Definitions + +Sometimes it can be useful to have custom JSON Marshal and Unmarshal methods in your structs that automatically convert for example a string into an object. + +This library will recognize and attempt to call four different methods that help you adjust schemas to your specific needs: + +- `JSONSchema() *Schema` - will prevent auto-generation of the schema so that you can provide your own definition. +- `JSONSchemaExtend(schema *jsonschema.Schema)` - will be called _after_ the schema has been generated, allowing you to add or manipulate the fields easily. +- `JSONSchemaAlias() any` - is called when reflecting the type of object and allows for an alternative to be used instead. +- `JSONSchemaProperty(prop string) any` - will be called for every property inside a struct giving you the chance to provide an alternative object to convert into a schema. + +Note that all of these methods **must** be defined on a non-pointer object for them to be called. + +Take the following simplified example of a `CompactDate` that only includes the Year and Month: + +```go +type CompactDate struct { + Year int + Month int +} + +func (d *CompactDate) UnmarshalJSON(data []byte) error { + if len(data) != 9 { + return errors.New("invalid compact date length") + } + var err error + d.Year, err = strconv.Atoi(string(data[1:5])) + if err != nil { + return err + } + d.Month, err = strconv.Atoi(string(data[7:8])) + if err != nil { + return err + } + return nil +} + +func (d *CompactDate) MarshalJSON() ([]byte, error) { + buf := new(bytes.Buffer) + buf.WriteByte('"') + buf.WriteString(fmt.Sprintf("%d-%02d", d.Year, d.Month)) + buf.WriteByte('"') + return buf.Bytes(), nil +} + +func (CompactDate) JSONSchema() *Schema { + return &Schema{ + Type: "string", + Title: "Compact Date", + Description: "Short date that only includes year and month", + Pattern: "^[0-9]{4}-[0-1][0-9]$", + } +} +``` + +The resulting schema generated for this struct would look like: + +```json +{ + "$schema": "http://json-schema.org/draft/2020-12/schema", + "$ref": "#/$defs/CompactDate", + "$defs": { + "CompactDate": { + "pattern": "^[0-9]{4}-[0-1][0-9]$", + "type": "string", + "title": "Compact Date", + "description": "Short date that only includes year and month" + } + } +} +``` diff --git a/vendor/github.com/invopop/jsonschema/id.go b/vendor/github.com/invopop/jsonschema/id.go new file mode 100644 index 0000000000..73fafb38d0 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/id.go @@ -0,0 +1,76 @@ +package jsonschema + +import ( + "errors" + "fmt" + "net/url" + "strings" +) + +// ID represents a Schema ID type which should always be a URI. +// See draft-bhutton-json-schema-00 section 8.2.1 +type ID string + +// EmptyID is used to explicitly define an ID with no value. +const EmptyID ID = "" + +// Validate is used to check if the ID looks like a proper schema. +// This is done by parsing the ID as a URL and checking it has all the +// relevant parts. +func (id ID) Validate() error { + u, err := url.Parse(id.String()) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + if u.Hostname() == "" { + return errors.New("missing hostname") + } + if !strings.Contains(u.Hostname(), ".") { + return errors.New("hostname does not look valid") + } + if u.Path == "" { + return errors.New("path is expected") + } + if u.Scheme != "https" && u.Scheme != "http" { + return errors.New("unexpected schema") + } + return nil +} + +// Anchor sets the anchor part of the schema URI. +func (id ID) Anchor(name string) ID { + b := id.Base() + return ID(b.String() + "#" + name) +} + +// Def adds or replaces a definition identifier. +func (id ID) Def(name string) ID { + b := id.Base() + return ID(b.String() + "#/$defs/" + name) +} + +// Add appends the provided path to the id, and removes any +// anchor data that might be there. +func (id ID) Add(path string) ID { + b := id.Base() + if !strings.HasPrefix(path, "/") { + path = "/" + path + } + return ID(b.String() + path) +} + +// Base removes any anchor information from the schema +func (id ID) Base() ID { + s := id.String() + i := strings.LastIndex(s, "#") + if i != -1 { + s = s[0:i] + } + s = strings.TrimRight(s, "/") + return ID(s) +} + +// String provides string version of ID +func (id ID) String() string { + return string(id) +} diff --git a/vendor/github.com/invopop/jsonschema/reflect.go b/vendor/github.com/invopop/jsonschema/reflect.go new file mode 100644 index 0000000000..73ce7e465b --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/reflect.go @@ -0,0 +1,1148 @@ +// Package jsonschema uses reflection to generate JSON Schemas from Go types [1]. +// +// If json tags are present on struct fields, they will be used to infer +// property names and if a property is required (omitempty is present). +// +// [1] http://json-schema.org/latest/json-schema-validation.html +package jsonschema + +import ( + "bytes" + "encoding/json" + "net" + "net/url" + "reflect" + "strconv" + "strings" + "time" +) + +// customSchemaImpl is used to detect if the type provides it's own +// custom Schema Type definition to use instead. Very useful for situations +// where there are custom JSON Marshal and Unmarshal methods. +type customSchemaImpl interface { + JSONSchema() *Schema +} + +// Function to be run after the schema has been generated. +// this will let you modify a schema afterwards +type extendSchemaImpl interface { + JSONSchemaExtend(*Schema) +} + +// If the object to be reflected defines a `JSONSchemaAlias` method, its type will +// be used instead of the original type. +type aliasSchemaImpl interface { + JSONSchemaAlias() any +} + +// If an object to be reflected defines a `JSONSchemaPropertyAlias` method, +// it will be called for each property to determine if another object +// should be used for the contents. +type propertyAliasSchemaImpl interface { + JSONSchemaProperty(prop string) any +} + +var customAliasSchema = reflect.TypeOf((*aliasSchemaImpl)(nil)).Elem() +var customPropertyAliasSchema = reflect.TypeOf((*propertyAliasSchemaImpl)(nil)).Elem() + +var customType = reflect.TypeOf((*customSchemaImpl)(nil)).Elem() +var extendType = reflect.TypeOf((*extendSchemaImpl)(nil)).Elem() + +// customSchemaGetFieldDocString +type customSchemaGetFieldDocString interface { + GetFieldDocString(fieldName string) string +} + +type customGetFieldDocString func(fieldName string) string + +var customStructGetFieldDocString = reflect.TypeOf((*customSchemaGetFieldDocString)(nil)).Elem() + +// Reflect reflects to Schema from a value using the default Reflector +func Reflect(v any) *Schema { + return ReflectFromType(reflect.TypeOf(v)) +} + +// ReflectFromType generates root schema using the default Reflector +func ReflectFromType(t reflect.Type) *Schema { + r := &Reflector{} + return r.ReflectFromType(t) +} + +// A Reflector reflects values into a Schema. +type Reflector struct { + // BaseSchemaID defines the URI that will be used as a base to determine Schema + // IDs for models. For example, a base Schema ID of `https://invopop.com/schemas` + // when defined with a struct called `User{}`, will result in a schema with an + // ID set to `https://invopop.com/schemas/user`. + // + // If no `BaseSchemaID` is provided, we'll take the type's complete package path + // and use that as a base instead. Set `Anonymous` to try if you do not want to + // include a schema ID. + BaseSchemaID ID + + // Anonymous when true will hide the auto-generated Schema ID and provide what is + // known as an "anonymous schema". As a rule, this is not recommended. + Anonymous bool + + // AssignAnchor when true will use the original struct's name as an anchor inside + // every definition, including the root schema. These can be useful for having a + // reference to the original struct's name in CamelCase instead of the snake-case used + // by default for URI compatibility. + // + // Anchors do not appear to be widely used out in the wild, so at this time the + // anchors themselves will not be used inside generated schema. + AssignAnchor bool + + // AllowAdditionalProperties will cause the Reflector to generate a schema + // without additionalProperties set to 'false' for all struct types. This means + // the presence of additional keys in JSON objects will not cause validation + // to fail. Note said additional keys will simply be dropped when the + // validated JSON is unmarshaled. + AllowAdditionalProperties bool + + // RequiredFromJSONSchemaTags will cause the Reflector to generate a schema + // that requires any key tagged with `jsonschema:required`, overriding the + // default of requiring any key *not* tagged with `json:,omitempty`. + RequiredFromJSONSchemaTags bool + + // Do not reference definitions. This will remove the top-level $defs map and + // instead cause the entire structure of types to be output in one tree. The + // list of type definitions (`$defs`) will not be included. + DoNotReference bool + + // ExpandedStruct when true will include the reflected type's definition in the + // root as opposed to a definition with a reference. + ExpandedStruct bool + + // FieldNameTag will change the tag used to get field names. json tags are used by default. + FieldNameTag string + + // IgnoredTypes defines a slice of types that should be ignored in the schema, + // switching to just allowing additional properties instead. + IgnoredTypes []any + + // Lookup allows a function to be defined that will provide a custom mapping of + // types to Schema IDs. This allows existing schema documents to be referenced + // by their ID instead of being embedded into the current schema definitions. + // Reflected types will never be pointers, only underlying elements. + Lookup func(reflect.Type) ID + + // Mapper is a function that can be used to map custom Go types to jsonschema schemas. + Mapper func(reflect.Type) *Schema + + // Namer allows customizing of type names. The default is to use the type's name + // provided by the reflect package. + Namer func(reflect.Type) string + + // KeyNamer allows customizing of key names. + // The default is to use the key's name as is, or the json tag if present. + // If a json tag is present, KeyNamer will receive the tag's name as an argument, not the original key name. + KeyNamer func(string) string + + // AdditionalFields allows adding structfields for a given type + AdditionalFields func(reflect.Type) []reflect.StructField + + // LookupComment allows customizing comment lookup. Given a reflect.Type and optionally + // a field name, it should return the comment string associated with this type or field. + // + // If the field name is empty, it should return the type's comment; otherwise, the field's + // comment should be returned. If no comment is found, an empty string should be returned. + // + // When set, this function is called before the below CommentMap lookup mechanism. However, + // if it returns an empty string, the CommentMap is still consulted. + LookupComment func(reflect.Type, string) string + + // CommentMap is a dictionary of fully qualified go types and fields to comment + // strings that will be used if a description has not already been provided in + // the tags. Types and fields are added to the package path using "." as a + // separator. + // + // Type descriptions should be defined like: + // + // map[string]string{"github.com/invopop/jsonschema.Reflector": "A Reflector reflects values into a Schema."} + // + // And Fields defined as: + // + // map[string]string{"github.com/invopop/jsonschema.Reflector.DoNotReference": "Do not reference definitions."} + // + // See also: AddGoComments, LookupComment + CommentMap map[string]string +} + +// Reflect reflects to Schema from a value. +func (r *Reflector) Reflect(v any) *Schema { + return r.ReflectFromType(reflect.TypeOf(v)) +} + +// ReflectFromType generates root schema +func (r *Reflector) ReflectFromType(t reflect.Type) *Schema { + if t.Kind() == reflect.Ptr { + t = t.Elem() // re-assign from pointer + } + + name := r.typeName(t) + + s := new(Schema) + definitions := Definitions{} + s.Definitions = definitions + bs := r.reflectTypeToSchemaWithID(definitions, t) + if r.ExpandedStruct { + *s = *definitions[name] + delete(definitions, name) + } else { + *s = *bs + } + + // Attempt to set the schema ID + if !r.Anonymous && s.ID == EmptyID { + baseSchemaID := r.BaseSchemaID + if baseSchemaID == EmptyID { + id := ID("https://" + t.PkgPath()) + if err := id.Validate(); err == nil { + // it's okay to silently ignore URL errors + baseSchemaID = id + } + } + if baseSchemaID != EmptyID { + s.ID = baseSchemaID.Add(ToSnakeCase(name)) + } + } + + s.Version = Version + if !r.DoNotReference { + s.Definitions = definitions + } + + return s +} + +// Available Go defined types for JSON Schema Validation. +// RFC draft-wright-json-schema-validation-00, section 7.3 +var ( + timeType = reflect.TypeOf(time.Time{}) // date-time RFC section 7.3.1 + ipType = reflect.TypeOf(net.IP{}) // ipv4 and ipv6 RFC section 7.3.4, 7.3.5 + uriType = reflect.TypeOf(url.URL{}) // uri RFC section 7.3.6 +) + +// Byte slices will be encoded as base64 +var byteSliceType = reflect.TypeOf([]byte(nil)) + +// Except for json.RawMessage +var rawMessageType = reflect.TypeOf(json.RawMessage{}) + +// Go code generated from protobuf enum types should fulfil this interface. +type protoEnum interface { + EnumDescriptor() ([]byte, []int) +} + +var protoEnumType = reflect.TypeOf((*protoEnum)(nil)).Elem() + +// SetBaseSchemaID is a helper use to be able to set the reflectors base +// schema ID from a string as opposed to then ID instance. +func (r *Reflector) SetBaseSchemaID(id string) { + r.BaseSchemaID = ID(id) +} + +func (r *Reflector) refOrReflectTypeToSchema(definitions Definitions, t reflect.Type) *Schema { + id := r.lookupID(t) + if id != EmptyID { + return &Schema{ + Ref: id.String(), + } + } + + // Already added to definitions? + if def := r.refDefinition(definitions, t); def != nil { + return def + } + + return r.reflectTypeToSchemaWithID(definitions, t) +} + +func (r *Reflector) reflectTypeToSchemaWithID(defs Definitions, t reflect.Type) *Schema { + s := r.reflectTypeToSchema(defs, t) + if s != nil { + if r.Lookup != nil { + id := r.Lookup(t) + if id != EmptyID { + s.ID = id + } + } + } + return s +} + +func (r *Reflector) reflectTypeToSchema(definitions Definitions, t reflect.Type) *Schema { + // only try to reflect non-pointers + if t.Kind() == reflect.Ptr { + return r.refOrReflectTypeToSchema(definitions, t.Elem()) + } + + // Check if the there is an alias method that provides an object + // that we should use instead of this one. + if t.Implements(customAliasSchema) { + v := reflect.New(t) + o := v.Interface().(aliasSchemaImpl) + t = reflect.TypeOf(o.JSONSchemaAlias()) + return r.refOrReflectTypeToSchema(definitions, t) + } + + // Do any pre-definitions exist? + if r.Mapper != nil { + if t := r.Mapper(t); t != nil { + return t + } + } + if rt := r.reflectCustomSchema(definitions, t); rt != nil { + return rt + } + + // Prepare a base to which details can be added + st := new(Schema) + + // jsonpb will marshal protobuf enum options as either strings or integers. + // It will unmarshal either. + if t.Implements(protoEnumType) { + st.OneOf = []*Schema{ + {Type: "string"}, + {Type: "integer"}, + } + return st + } + + // Defined format types for JSON Schema Validation + // RFC draft-wright-json-schema-validation-00, section 7.3 + // TODO email RFC section 7.3.2, hostname RFC section 7.3.3, uriref RFC section 7.3.7 + if t == ipType { + // TODO differentiate ipv4 and ipv6 RFC section 7.3.4, 7.3.5 + st.Type = "string" + st.Format = "ipv4" + return st + } + + switch t.Kind() { + case reflect.Struct: + r.reflectStruct(definitions, t, st) + + case reflect.Slice, reflect.Array: + r.reflectSliceOrArray(definitions, t, st) + + case reflect.Map: + r.reflectMap(definitions, t, st) + + case reflect.Interface: + // empty + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + st.Type = "integer" + + case reflect.Float32, reflect.Float64: + st.Type = "number" + + case reflect.Bool: + st.Type = "boolean" + + case reflect.String: + st.Type = "string" + + default: + panic("unsupported type " + t.String()) + } + + r.reflectSchemaExtend(definitions, t, st) + + // Always try to reference the definition which may have just been created + if def := r.refDefinition(definitions, t); def != nil { + return def + } + + return st +} + +func (r *Reflector) reflectCustomSchema(definitions Definitions, t reflect.Type) *Schema { + if t.Kind() == reflect.Ptr { + return r.reflectCustomSchema(definitions, t.Elem()) + } + + if t.Implements(customType) { + v := reflect.New(t) + o := v.Interface().(customSchemaImpl) + st := o.JSONSchema() + r.addDefinition(definitions, t, st) + if ref := r.refDefinition(definitions, t); ref != nil { + return ref + } + return st + } + + return nil +} + +func (r *Reflector) reflectSchemaExtend(definitions Definitions, t reflect.Type, s *Schema) *Schema { + if t.Implements(extendType) { + v := reflect.New(t) + o := v.Interface().(extendSchemaImpl) + o.JSONSchemaExtend(s) + if ref := r.refDefinition(definitions, t); ref != nil { + return ref + } + } + + return s +} + +func (r *Reflector) reflectSliceOrArray(definitions Definitions, t reflect.Type, st *Schema) { + if t == rawMessageType { + return + } + + r.addDefinition(definitions, t, st) + + if st.Description == "" { + st.Description = r.lookupComment(t, "") + } + + if t.Kind() == reflect.Array { + l := uint64(t.Len()) + st.MinItems = &l + st.MaxItems = &l + } + if t.Kind() == reflect.Slice && t.Elem() == byteSliceType.Elem() { + st.Type = "string" + // NOTE: ContentMediaType is not set here + st.ContentEncoding = "base64" + } else { + st.Type = "array" + st.Items = r.refOrReflectTypeToSchema(definitions, t.Elem()) + } +} + +func (r *Reflector) reflectMap(definitions Definitions, t reflect.Type, st *Schema) { + r.addDefinition(definitions, t, st) + + st.Type = "object" + if st.Description == "" { + st.Description = r.lookupComment(t, "") + } + + switch t.Key().Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + st.PatternProperties = map[string]*Schema{ + "^[0-9]+$": r.refOrReflectTypeToSchema(definitions, t.Elem()), + } + st.AdditionalProperties = FalseSchema + return + } + if t.Elem().Kind() != reflect.Interface { + st.AdditionalProperties = r.refOrReflectTypeToSchema(definitions, t.Elem()) + } +} + +// Reflects a struct to a JSON Schema type. +func (r *Reflector) reflectStruct(definitions Definitions, t reflect.Type, s *Schema) { + // Handle special types + switch t { + case timeType: // date-time RFC section 7.3.1 + s.Type = "string" + s.Format = "date-time" + return + case uriType: // uri RFC section 7.3.6 + s.Type = "string" + s.Format = "uri" + return + } + + r.addDefinition(definitions, t, s) + s.Type = "object" + s.Properties = NewProperties() + s.Description = r.lookupComment(t, "") + if r.AssignAnchor { + s.Anchor = t.Name() + } + if !r.AllowAdditionalProperties && s.AdditionalProperties == nil { + s.AdditionalProperties = FalseSchema + } + + ignored := false + for _, it := range r.IgnoredTypes { + if reflect.TypeOf(it) == t { + ignored = true + break + } + } + if !ignored { + r.reflectStructFields(s, definitions, t) + } +} + +func (r *Reflector) reflectStructFields(st *Schema, definitions Definitions, t reflect.Type) { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return + } + + var getFieldDocString customGetFieldDocString + if t.Implements(customStructGetFieldDocString) { + v := reflect.New(t) + o := v.Interface().(customSchemaGetFieldDocString) + getFieldDocString = o.GetFieldDocString + } + + customPropertyMethod := func(string) any { + return nil + } + if t.Implements(customPropertyAliasSchema) { + v := reflect.New(t) + o := v.Interface().(propertyAliasSchemaImpl) + customPropertyMethod = o.JSONSchemaProperty + } + + handleField := func(f reflect.StructField) { + name, shouldEmbed, required, nullable := r.reflectFieldName(f) + // if anonymous and exported type should be processed recursively + // current type should inherit properties of anonymous one + if name == "" { + if shouldEmbed { + r.reflectStructFields(st, definitions, f.Type) + } + return + } + + // If a JSONSchemaAlias(prop string) method is defined, attempt to use + // the provided object's type instead of the field's type. + var property *Schema + if alias := customPropertyMethod(name); alias != nil { + property = r.refOrReflectTypeToSchema(definitions, reflect.TypeOf(alias)) + } else { + property = r.refOrReflectTypeToSchema(definitions, f.Type) + } + + property.structKeywordsFromTags(f, st, name) + if property.Description == "" { + property.Description = r.lookupComment(t, f.Name) + } + if getFieldDocString != nil { + property.Description = getFieldDocString(f.Name) + } + + if nullable { + property = &Schema{ + OneOf: []*Schema{ + property, + { + Type: "null", + }, + }, + } + } + + st.Properties.Set(name, property) + if required { + st.Required = appendUniqueString(st.Required, name) + } + } + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + handleField(f) + } + if r.AdditionalFields != nil { + if af := r.AdditionalFields(t); af != nil { + for _, sf := range af { + handleField(sf) + } + } + } +} + +func appendUniqueString(base []string, value string) []string { + for _, v := range base { + if v == value { + return base + } + } + return append(base, value) +} + +// addDefinition will append the provided schema. If needed, an ID and anchor will also be added. +func (r *Reflector) addDefinition(definitions Definitions, t reflect.Type, s *Schema) { + name := r.typeName(t) + if name == "" { + return + } + definitions[name] = s +} + +// refDefinition will provide a schema with a reference to an existing definition. +func (r *Reflector) refDefinition(definitions Definitions, t reflect.Type) *Schema { + if r.DoNotReference { + return nil + } + name := r.typeName(t) + if name == "" { + return nil + } + if _, ok := definitions[name]; !ok { + return nil + } + return &Schema{ + Ref: "#/$defs/" + name, + } +} + +func (r *Reflector) lookupID(t reflect.Type) ID { + if r.Lookup != nil { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return r.Lookup(t) + + } + return EmptyID +} + +func (t *Schema) structKeywordsFromTags(f reflect.StructField, parent *Schema, propertyName string) { + t.Description = f.Tag.Get("jsonschema_description") + + tags := splitOnUnescapedCommas(f.Tag.Get("jsonschema")) + tags = t.genericKeywords(tags, parent, propertyName) + + switch t.Type { + case "string": + t.stringKeywords(tags) + case "number": + t.numericalKeywords(tags) + case "integer": + t.numericalKeywords(tags) + case "array": + t.arrayKeywords(tags) + case "boolean": + t.booleanKeywords(tags) + } + extras := strings.Split(f.Tag.Get("jsonschema_extras"), ",") + t.extraKeywords(extras) +} + +// read struct tags for generic keywords +func (t *Schema) genericKeywords(tags []string, parent *Schema, propertyName string) []string { //nolint:gocyclo + unprocessed := make([]string, 0, len(tags)) + for _, tag := range tags { + nameValue := strings.SplitN(tag, "=", 2) + if len(nameValue) == 2 { + name, val := nameValue[0], nameValue[1] + switch name { + case "title": + t.Title = val + case "description": + t.Description = val + case "type": + t.Type = val + case "anchor": + t.Anchor = val + case "oneof_required": + var typeFound *Schema + for i := range parent.OneOf { + if parent.OneOf[i].Title == nameValue[1] { + typeFound = parent.OneOf[i] + } + } + if typeFound == nil { + typeFound = &Schema{ + Title: nameValue[1], + Required: []string{}, + } + parent.OneOf = append(parent.OneOf, typeFound) + } + typeFound.Required = append(typeFound.Required, propertyName) + case "anyof_required": + var typeFound *Schema + for i := range parent.AnyOf { + if parent.AnyOf[i].Title == nameValue[1] { + typeFound = parent.AnyOf[i] + } + } + if typeFound == nil { + typeFound = &Schema{ + Title: nameValue[1], + Required: []string{}, + } + parent.AnyOf = append(parent.AnyOf, typeFound) + } + typeFound.Required = append(typeFound.Required, propertyName) + case "oneof_ref": + subSchema := t + if t.Items != nil { + subSchema = t.Items + } + if subSchema.OneOf == nil { + subSchema.OneOf = make([]*Schema, 0, 1) + } + subSchema.Ref = "" + refs := strings.Split(nameValue[1], ";") + for _, r := range refs { + subSchema.OneOf = append(subSchema.OneOf, &Schema{ + Ref: r, + }) + } + case "oneof_type": + if t.OneOf == nil { + t.OneOf = make([]*Schema, 0, 1) + } + t.Type = "" + types := strings.Split(nameValue[1], ";") + for _, ty := range types { + t.OneOf = append(t.OneOf, &Schema{ + Type: ty, + }) + } + case "anyof_ref": + subSchema := t + if t.Items != nil { + subSchema = t.Items + } + if subSchema.AnyOf == nil { + subSchema.AnyOf = make([]*Schema, 0, 1) + } + subSchema.Ref = "" + refs := strings.Split(nameValue[1], ";") + for _, r := range refs { + subSchema.AnyOf = append(subSchema.AnyOf, &Schema{ + Ref: r, + }) + } + case "anyof_type": + if t.AnyOf == nil { + t.AnyOf = make([]*Schema, 0, 1) + } + t.Type = "" + types := strings.Split(nameValue[1], ";") + for _, ty := range types { + t.AnyOf = append(t.AnyOf, &Schema{ + Type: ty, + }) + } + default: + unprocessed = append(unprocessed, tag) + } + } + } + return unprocessed +} + +// read struct tags for boolean type keywords +func (t *Schema) booleanKeywords(tags []string) { + for _, tag := range tags { + nameValue := strings.Split(tag, "=") + if len(nameValue) != 2 { + continue + } + name, val := nameValue[0], nameValue[1] + if name == "default" { + if val == "true" { + t.Default = true + } else if val == "false" { + t.Default = false + } + } + } +} + +// read struct tags for string type keywords +func (t *Schema) stringKeywords(tags []string) { + for _, tag := range tags { + nameValue := strings.SplitN(tag, "=", 2) + if len(nameValue) == 2 { + name, val := nameValue[0], nameValue[1] + switch name { + case "minLength": + t.MinLength = parseUint(val) + case "maxLength": + t.MaxLength = parseUint(val) + case "pattern": + t.Pattern = val + case "format": + t.Format = val + case "readOnly": + i, _ := strconv.ParseBool(val) + t.ReadOnly = i + case "writeOnly": + i, _ := strconv.ParseBool(val) + t.WriteOnly = i + case "default": + t.Default = val + case "example": + t.Examples = append(t.Examples, val) + case "enum": + t.Enum = append(t.Enum, val) + } + } + } +} + +// read struct tags for numerical type keywords +func (t *Schema) numericalKeywords(tags []string) { + for _, tag := range tags { + nameValue := strings.Split(tag, "=") + if len(nameValue) == 2 { + name, val := nameValue[0], nameValue[1] + switch name { + case "multipleOf": + t.MultipleOf, _ = toJSONNumber(val) + case "minimum": + t.Minimum, _ = toJSONNumber(val) + case "maximum": + t.Maximum, _ = toJSONNumber(val) + case "exclusiveMaximum": + t.ExclusiveMaximum, _ = toJSONNumber(val) + case "exclusiveMinimum": + t.ExclusiveMinimum, _ = toJSONNumber(val) + case "default": + if num, ok := toJSONNumber(val); ok { + t.Default = num + } + case "example": + if num, ok := toJSONNumber(val); ok { + t.Examples = append(t.Examples, num) + } + case "enum": + if num, ok := toJSONNumber(val); ok { + t.Enum = append(t.Enum, num) + } + } + } + } +} + +// read struct tags for object type keywords +// func (t *Type) objectKeywords(tags []string) { +// for _, tag := range tags{ +// nameValue := strings.Split(tag, "=") +// name, val := nameValue[0], nameValue[1] +// switch name{ +// case "dependencies": +// t.Dependencies = val +// break; +// case "patternProperties": +// t.PatternProperties = val +// break; +// } +// } +// } + +// read struct tags for array type keywords +func (t *Schema) arrayKeywords(tags []string) { + var defaultValues []any + + unprocessed := make([]string, 0, len(tags)) + for _, tag := range tags { + nameValue := strings.Split(tag, "=") + if len(nameValue) == 2 { + name, val := nameValue[0], nameValue[1] + switch name { + case "minItems": + t.MinItems = parseUint(val) + case "maxItems": + t.MaxItems = parseUint(val) + case "uniqueItems": + t.UniqueItems = true + case "default": + defaultValues = append(defaultValues, val) + case "format": + t.Items.Format = val + case "pattern": + t.Items.Pattern = val + default: + unprocessed = append(unprocessed, tag) // left for further processing by underlying type + } + } + } + if len(defaultValues) > 0 { + t.Default = defaultValues + } + + if len(unprocessed) == 0 { + // we don't have anything else to process + return + } + + switch t.Items.Type { + case "string": + t.Items.stringKeywords(unprocessed) + case "number": + t.Items.numericalKeywords(unprocessed) + case "integer": + t.Items.numericalKeywords(unprocessed) + case "array": + // explicitly don't support traversal for the [][]..., as it's unclear where the array tags belong + case "boolean": + t.Items.booleanKeywords(unprocessed) + } +} + +func (t *Schema) extraKeywords(tags []string) { + for _, tag := range tags { + nameValue := strings.SplitN(tag, "=", 2) + if len(nameValue) == 2 { + t.setExtra(nameValue[0], nameValue[1]) + } + } +} + +func (t *Schema) setExtra(key, val string) { + if t.Extras == nil { + t.Extras = map[string]any{} + } + if existingVal, ok := t.Extras[key]; ok { + switch existingVal := existingVal.(type) { + case string: + t.Extras[key] = []string{existingVal, val} + case []string: + t.Extras[key] = append(existingVal, val) + case int: + t.Extras[key], _ = strconv.Atoi(val) + case bool: + t.Extras[key] = (val == "true" || val == "t") + } + } else { + switch key { + case "minimum": + t.Extras[key], _ = strconv.Atoi(val) + default: + var x any + if val == "true" { + x = true + } else if val == "false" { + x = false + } else { + x = val + } + t.Extras[key] = x + } + } +} + +func requiredFromJSONTags(tags []string, val *bool) { + if ignoredByJSONTags(tags) { + return + } + + for _, tag := range tags[1:] { + if tag == "omitempty" { + *val = false + return + } + } + *val = true +} + +func requiredFromJSONSchemaTags(tags []string, val *bool) { + if ignoredByJSONSchemaTags(tags) { + return + } + for _, tag := range tags { + if tag == "required" { + *val = true + } + } +} + +func nullableFromJSONSchemaTags(tags []string) bool { + if ignoredByJSONSchemaTags(tags) { + return false + } + for _, tag := range tags { + if tag == "nullable" { + return true + } + } + return false +} + +func ignoredByJSONTags(tags []string) bool { + return tags[0] == "-" +} + +func ignoredByJSONSchemaTags(tags []string) bool { + return tags[0] == "-" +} + +func inlinedByJSONTags(tags []string) bool { + for _, tag := range tags[1:] { + if tag == "inline" { + return true + } + } + return false +} + +// toJSONNumber converts string to *json.Number. +// It'll aso return whether the number is valid. +func toJSONNumber(s string) (json.Number, bool) { + num := json.Number(s) + if _, err := num.Int64(); err == nil { + return num, true + } + if _, err := num.Float64(); err == nil { + return num, true + } + return json.Number(""), false +} + +func parseUint(num string) *uint64 { + val, err := strconv.ParseUint(num, 10, 64) + if err != nil { + return nil + } + return &val +} + +func (r *Reflector) fieldNameTag() string { + if r.FieldNameTag != "" { + return r.FieldNameTag + } + return "json" +} + +func (r *Reflector) reflectFieldName(f reflect.StructField) (string, bool, bool, bool) { + jsonTagString := f.Tag.Get(r.fieldNameTag()) + jsonTags := strings.Split(jsonTagString, ",") + + if ignoredByJSONTags(jsonTags) { + return "", false, false, false + } + + schemaTags := strings.Split(f.Tag.Get("jsonschema"), ",") + if ignoredByJSONSchemaTags(schemaTags) { + return "", false, false, false + } + + var required bool + if !r.RequiredFromJSONSchemaTags { + requiredFromJSONTags(jsonTags, &required) + } + requiredFromJSONSchemaTags(schemaTags, &required) + + nullable := nullableFromJSONSchemaTags(schemaTags) + + if f.Anonymous && jsonTags[0] == "" { + // As per JSON Marshal rules, anonymous structs are inherited + if f.Type.Kind() == reflect.Struct { + return "", true, false, false + } + + // As per JSON Marshal rules, anonymous pointer to structs are inherited + if f.Type.Kind() == reflect.Ptr && f.Type.Elem().Kind() == reflect.Struct { + return "", true, false, false + } + } + + // As per JSON Marshal rules, inline nested structs that have `inline` tag. + if inlinedByJSONTags(jsonTags) { + return "", true, false, false + } + + // Try to determine the name from the different combos + name := f.Name + if jsonTags[0] != "" { + name = jsonTags[0] + } + if !f.Anonymous && f.PkgPath != "" { + // field not anonymous and not export has no export name + name = "" + } else if r.KeyNamer != nil { + name = r.KeyNamer(name) + } + + return name, false, required, nullable +} + +// UnmarshalJSON is used to parse a schema object or boolean. +func (t *Schema) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, []byte("true")) { + *t = *TrueSchema + return nil + } else if bytes.Equal(data, []byte("false")) { + *t = *FalseSchema + return nil + } + type SchemaAlt Schema + aux := &struct { + *SchemaAlt + }{ + SchemaAlt: (*SchemaAlt)(t), + } + return json.Unmarshal(data, aux) +} + +// MarshalJSON is used to serialize a schema object or boolean. +func (t *Schema) MarshalJSON() ([]byte, error) { + if t.boolean != nil { + if *t.boolean { + return []byte("true"), nil + } + return []byte("false"), nil + } + if reflect.DeepEqual(&Schema{}, t) { + // Don't bother returning empty schemas + return []byte("true"), nil + } + type SchemaAlt Schema + b, err := json.Marshal((*SchemaAlt)(t)) + if err != nil { + return nil, err + } + if len(t.Extras) == 0 { + return b, nil + } + m, err := json.Marshal(t.Extras) + if err != nil { + return nil, err + } + if len(b) == 2 { + return m, nil + } + b[len(b)-1] = ',' + return append(b, m[1:]...), nil +} + +func (r *Reflector) typeName(t reflect.Type) string { + if r.Namer != nil { + if name := r.Namer(t); name != "" { + return name + } + } + return t.Name() +} + +// Split on commas that are not preceded by `\`. +// This way, we prevent splitting regexes +func splitOnUnescapedCommas(tagString string) []string { + ret := make([]string, 0) + separated := strings.Split(tagString, ",") + ret = append(ret, separated[0]) + i := 0 + for _, nextTag := range separated[1:] { + if len(ret[i]) == 0 { + ret = append(ret, nextTag) + i++ + continue + } + + if ret[i][len(ret[i])-1] == '\\' { + ret[i] = ret[i][:len(ret[i])-1] + "," + nextTag + } else { + ret = append(ret, nextTag) + i++ + } + } + + return ret +} + +func fullyQualifiedTypeName(t reflect.Type) string { + return t.PkgPath() + "." + t.Name() +} diff --git a/vendor/github.com/invopop/jsonschema/reflect_comments.go b/vendor/github.com/invopop/jsonschema/reflect_comments.go new file mode 100644 index 0000000000..ff374c75c8 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/reflect_comments.go @@ -0,0 +1,146 @@ +package jsonschema + +import ( + "fmt" + "io/fs" + gopath "path" + "path/filepath" + "reflect" + "strings" + + "go/ast" + "go/doc" + "go/parser" + "go/token" +) + +type commentOptions struct { + fullObjectText bool // use the first sentence only? +} + +// CommentOption allows for special configuration options when preparing Go +// source files for comment extraction. +type CommentOption func(*commentOptions) + +// WithFullComment will configure the comment extraction to process to use an +// object type's full comment text instead of just the synopsis. +func WithFullComment() CommentOption { + return func(o *commentOptions) { + o.fullObjectText = true + } +} + +// AddGoComments will update the reflectors comment map with all the comments +// found in the provided source directories including sub-directories, in order to +// generate a dictionary of comments associated with Types and Fields. The results +// will be added to the `Reflect.CommentMap` ready to use with Schema "description" +// fields. +// +// The `go/parser` library is used to extract all the comments and unfortunately doesn't +// have a built-in way to determine the fully qualified name of a package. The `base` +// parameter, the URL used to import that package, is thus required to be able to match +// reflected types. +// +// When parsing type comments, by default we use the `go/doc`'s Synopsis method to extract +// the first phrase only. Field comments, which tend to be much shorter, will include everything. +// This behavior can be changed by using the `WithFullComment` option. +func (r *Reflector) AddGoComments(base, path string, opts ...CommentOption) error { + if r.CommentMap == nil { + r.CommentMap = make(map[string]string) + } + co := new(commentOptions) + for _, opt := range opts { + opt(co) + } + + return r.extractGoComments(base, path, r.CommentMap, co) +} + +func (r *Reflector) extractGoComments(base, path string, commentMap map[string]string, opts *commentOptions) error { + fset := token.NewFileSet() + dict := make(map[string][]*ast.Package) + err := filepath.Walk(path, func(path string, info fs.FileInfo, err error) error { + if err != nil { + return err + } + if info.IsDir() { + d, err := parser.ParseDir(fset, path, nil, parser.ParseComments) + if err != nil { + return err + } + for _, v := range d { + // paths may have multiple packages, like for tests + k := gopath.Join(base, path) + dict[k] = append(dict[k], v) + } + } + return nil + }) + if err != nil { + return err + } + + for pkg, p := range dict { + for _, f := range p { + gtxt := "" + typ := "" + ast.Inspect(f, func(n ast.Node) bool { + switch x := n.(type) { + case *ast.TypeSpec: + typ = x.Name.String() + if !ast.IsExported(typ) { + typ = "" + } else { + txt := x.Doc.Text() + if txt == "" && gtxt != "" { + txt = gtxt + gtxt = "" + } + if !opts.fullObjectText { + txt = doc.Synopsis(txt) + } + commentMap[fmt.Sprintf("%s.%s", pkg, typ)] = strings.TrimSpace(txt) + } + case *ast.Field: + txt := x.Doc.Text() + if txt == "" { + txt = x.Comment.Text() + } + if typ != "" && txt != "" { + for _, n := range x.Names { + if ast.IsExported(n.String()) { + k := fmt.Sprintf("%s.%s.%s", pkg, typ, n) + commentMap[k] = strings.TrimSpace(txt) + } + } + } + case *ast.GenDecl: + // remember for the next type + gtxt = x.Doc.Text() + } + return true + }) + } + } + + return nil +} + +func (r *Reflector) lookupComment(t reflect.Type, name string) string { + if r.LookupComment != nil { + if comment := r.LookupComment(t, name); comment != "" { + return comment + } + } + + if r.CommentMap == nil { + return "" + } + + n := fullyQualifiedTypeName(t) + if name != "" { + n = n + "." + name + } + + return r.CommentMap[n] +} diff --git a/vendor/github.com/invopop/jsonschema/schema.go b/vendor/github.com/invopop/jsonschema/schema.go new file mode 100644 index 0000000000..2d914b8c83 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/schema.go @@ -0,0 +1,94 @@ +package jsonschema + +import ( + "encoding/json" + + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +// Version is the JSON Schema version. +var Version = "https://json-schema.org/draft/2020-12/schema" + +// Schema represents a JSON Schema object type. +// RFC draft-bhutton-json-schema-00 section 4.3 +type Schema struct { + // RFC draft-bhutton-json-schema-00 + Version string `json:"$schema,omitempty"` // section 8.1.1 + ID ID `json:"$id,omitempty"` // section 8.2.1 + Anchor string `json:"$anchor,omitempty"` // section 8.2.2 + Ref string `json:"$ref,omitempty"` // section 8.2.3.1 + DynamicRef string `json:"$dynamicRef,omitempty"` // section 8.2.3.2 + Definitions Definitions `json:"$defs,omitempty"` // section 8.2.4 + Comments string `json:"$comment,omitempty"` // section 8.3 + // RFC draft-bhutton-json-schema-00 section 10.2.1 (Sub-schemas with logic) + AllOf []*Schema `json:"allOf,omitempty"` // section 10.2.1.1 + AnyOf []*Schema `json:"anyOf,omitempty"` // section 10.2.1.2 + OneOf []*Schema `json:"oneOf,omitempty"` // section 10.2.1.3 + Not *Schema `json:"not,omitempty"` // section 10.2.1.4 + // RFC draft-bhutton-json-schema-00 section 10.2.2 (Apply sub-schemas conditionally) + If *Schema `json:"if,omitempty"` // section 10.2.2.1 + Then *Schema `json:"then,omitempty"` // section 10.2.2.2 + Else *Schema `json:"else,omitempty"` // section 10.2.2.3 + DependentSchemas map[string]*Schema `json:"dependentSchemas,omitempty"` // section 10.2.2.4 + // RFC draft-bhutton-json-schema-00 section 10.3.1 (arrays) + PrefixItems []*Schema `json:"prefixItems,omitempty"` // section 10.3.1.1 + Items *Schema `json:"items,omitempty"` // section 10.3.1.2 (replaces additionalItems) + Contains *Schema `json:"contains,omitempty"` // section 10.3.1.3 + // RFC draft-bhutton-json-schema-00 section 10.3.2 (sub-schemas) + Properties *orderedmap.OrderedMap[string, *Schema] `json:"properties,omitempty"` // section 10.3.2.1 + PatternProperties map[string]*Schema `json:"patternProperties,omitempty"` // section 10.3.2.2 + AdditionalProperties *Schema `json:"additionalProperties,omitempty"` // section 10.3.2.3 + PropertyNames *Schema `json:"propertyNames,omitempty"` // section 10.3.2.4 + // RFC draft-bhutton-json-schema-validation-00, section 6 + Type string `json:"type,omitempty"` // section 6.1.1 + Enum []any `json:"enum,omitempty"` // section 6.1.2 + Const any `json:"const,omitempty"` // section 6.1.3 + MultipleOf json.Number `json:"multipleOf,omitempty"` // section 6.2.1 + Maximum json.Number `json:"maximum,omitempty"` // section 6.2.2 + ExclusiveMaximum json.Number `json:"exclusiveMaximum,omitempty"` // section 6.2.3 + Minimum json.Number `json:"minimum,omitempty"` // section 6.2.4 + ExclusiveMinimum json.Number `json:"exclusiveMinimum,omitempty"` // section 6.2.5 + MaxLength *uint64 `json:"maxLength,omitempty"` // section 6.3.1 + MinLength *uint64 `json:"minLength,omitempty"` // section 6.3.2 + Pattern string `json:"pattern,omitempty"` // section 6.3.3 + MaxItems *uint64 `json:"maxItems,omitempty"` // section 6.4.1 + MinItems *uint64 `json:"minItems,omitempty"` // section 6.4.2 + UniqueItems bool `json:"uniqueItems,omitempty"` // section 6.4.3 + MaxContains *uint64 `json:"maxContains,omitempty"` // section 6.4.4 + MinContains *uint64 `json:"minContains,omitempty"` // section 6.4.5 + MaxProperties *uint64 `json:"maxProperties,omitempty"` // section 6.5.1 + MinProperties *uint64 `json:"minProperties,omitempty"` // section 6.5.2 + Required []string `json:"required,omitempty"` // section 6.5.3 + DependentRequired map[string][]string `json:"dependentRequired,omitempty"` // section 6.5.4 + // RFC draft-bhutton-json-schema-validation-00, section 7 + Format string `json:"format,omitempty"` + // RFC draft-bhutton-json-schema-validation-00, section 8 + ContentEncoding string `json:"contentEncoding,omitempty"` // section 8.3 + ContentMediaType string `json:"contentMediaType,omitempty"` // section 8.4 + ContentSchema *Schema `json:"contentSchema,omitempty"` // section 8.5 + // RFC draft-bhutton-json-schema-validation-00, section 9 + Title string `json:"title,omitempty"` // section 9.1 + Description string `json:"description,omitempty"` // section 9.1 + Default any `json:"default,omitempty"` // section 9.2 + Deprecated bool `json:"deprecated,omitempty"` // section 9.3 + ReadOnly bool `json:"readOnly,omitempty"` // section 9.4 + WriteOnly bool `json:"writeOnly,omitempty"` // section 9.4 + Examples []any `json:"examples,omitempty"` // section 9.5 + + Extras map[string]any `json:"-"` + + // Special boolean representation of the Schema - section 4.3.2 + boolean *bool +} + +var ( + // TrueSchema defines a schema with a true value + TrueSchema = &Schema{boolean: &[]bool{true}[0]} + // FalseSchema defines a schema with a false value + FalseSchema = &Schema{boolean: &[]bool{false}[0]} +) + +// Definitions hold schema definitions. +// http://json-schema.org/latest/json-schema-validation.html#rfc.section.5.26 +// RFC draft-wright-json-schema-validation-00, section 5.26 +type Definitions map[string]*Schema diff --git a/vendor/github.com/invopop/jsonschema/utils.go b/vendor/github.com/invopop/jsonschema/utils.go new file mode 100644 index 0000000000..ed8edf7411 --- /dev/null +++ b/vendor/github.com/invopop/jsonschema/utils.go @@ -0,0 +1,26 @@ +package jsonschema + +import ( + "regexp" + "strings" + + orderedmap "github.com/wk8/go-ordered-map/v2" +) + +var matchFirstCap = regexp.MustCompile("(.)([A-Z][a-z]+)") +var matchAllCap = regexp.MustCompile("([a-z0-9])([A-Z])") + +// ToSnakeCase converts the provided string into snake case using dashes. +// This is useful for Schema IDs and definitions to be coherent with +// common JSON Schema examples. +func ToSnakeCase(str string) string { + snake := matchFirstCap.ReplaceAllString(str, "${1}-${2}") + snake = matchAllCap.ReplaceAllString(snake, "${1}-${2}") + return strings.ToLower(snake) +} + +// NewProperties is a helper method to instantiate a new properties ordered +// map. +func NewProperties() *orderedmap.OrderedMap[string, *Schema] { + return orderedmap.New[string, *Schema]() +} diff --git a/vendor/github.com/mailru/easyjson/LICENSE b/vendor/github.com/mailru/easyjson/LICENSE new file mode 100644 index 0000000000..fbff658f70 --- /dev/null +++ b/vendor/github.com/mailru/easyjson/LICENSE @@ -0,0 +1,7 @@ +Copyright (c) 2016 Mail.Ru Group + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/vendor/github.com/mailru/easyjson/buffer/pool.go b/vendor/github.com/mailru/easyjson/buffer/pool.go new file mode 100644 index 0000000000..598a54af9d --- /dev/null +++ b/vendor/github.com/mailru/easyjson/buffer/pool.go @@ -0,0 +1,278 @@ +// Package buffer implements a buffer for serialization, consisting of a chain of []byte-s to +// reduce copying and to allow reuse of individual chunks. +package buffer + +import ( + "io" + "net" + "sync" +) + +// PoolConfig contains configuration for the allocation and reuse strategy. +type PoolConfig struct { + StartSize int // Minimum chunk size that is allocated. + PooledSize int // Minimum chunk size that is reused, reusing chunks too small will result in overhead. + MaxSize int // Maximum chunk size that will be allocated. +} + +var config = PoolConfig{ + StartSize: 128, + PooledSize: 512, + MaxSize: 32768, +} + +// Reuse pool: chunk size -> pool. +var buffers = map[int]*sync.Pool{} + +func initBuffers() { + for l := config.PooledSize; l <= config.MaxSize; l *= 2 { + buffers[l] = new(sync.Pool) + } +} + +func init() { + initBuffers() +} + +// Init sets up a non-default pooling and allocation strategy. Should be run before serialization is done. +func Init(cfg PoolConfig) { + config = cfg + initBuffers() +} + +// putBuf puts a chunk to reuse pool if it can be reused. +func putBuf(buf []byte) { + size := cap(buf) + if size < config.PooledSize { + return + } + if c := buffers[size]; c != nil { + c.Put(buf[:0]) + } +} + +// getBuf gets a chunk from reuse pool or creates a new one if reuse failed. +func getBuf(size int) []byte { + if size >= config.PooledSize { + if c := buffers[size]; c != nil { + v := c.Get() + if v != nil { + return v.([]byte) + } + } + } + return make([]byte, 0, size) +} + +// Buffer is a buffer optimized for serialization without extra copying. +type Buffer struct { + + // Buf is the current chunk that can be used for serialization. + Buf []byte + + toPool []byte + bufs [][]byte +} + +// EnsureSpace makes sure that the current chunk contains at least s free bytes, +// possibly creating a new chunk. +func (b *Buffer) EnsureSpace(s int) { + if cap(b.Buf)-len(b.Buf) < s { + b.ensureSpaceSlow(s) + } +} + +func (b *Buffer) ensureSpaceSlow(s int) { + l := len(b.Buf) + if l > 0 { + if cap(b.toPool) != cap(b.Buf) { + // Chunk was reallocated, toPool can be pooled. + putBuf(b.toPool) + } + if cap(b.bufs) == 0 { + b.bufs = make([][]byte, 0, 8) + } + b.bufs = append(b.bufs, b.Buf) + l = cap(b.toPool) * 2 + } else { + l = config.StartSize + } + + if l > config.MaxSize { + l = config.MaxSize + } + b.Buf = getBuf(l) + b.toPool = b.Buf +} + +// AppendByte appends a single byte to buffer. +func (b *Buffer) AppendByte(data byte) { + b.EnsureSpace(1) + b.Buf = append(b.Buf, data) +} + +// AppendBytes appends a byte slice to buffer. +func (b *Buffer) AppendBytes(data []byte) { + if len(data) <= cap(b.Buf)-len(b.Buf) { + b.Buf = append(b.Buf, data...) // fast path + } else { + b.appendBytesSlow(data) + } +} + +func (b *Buffer) appendBytesSlow(data []byte) { + for len(data) > 0 { + b.EnsureSpace(1) + + sz := cap(b.Buf) - len(b.Buf) + if sz > len(data) { + sz = len(data) + } + + b.Buf = append(b.Buf, data[:sz]...) + data = data[sz:] + } +} + +// AppendString appends a string to buffer. +func (b *Buffer) AppendString(data string) { + if len(data) <= cap(b.Buf)-len(b.Buf) { + b.Buf = append(b.Buf, data...) // fast path + } else { + b.appendStringSlow(data) + } +} + +func (b *Buffer) appendStringSlow(data string) { + for len(data) > 0 { + b.EnsureSpace(1) + + sz := cap(b.Buf) - len(b.Buf) + if sz > len(data) { + sz = len(data) + } + + b.Buf = append(b.Buf, data[:sz]...) + data = data[sz:] + } +} + +// Size computes the size of a buffer by adding sizes of every chunk. +func (b *Buffer) Size() int { + size := len(b.Buf) + for _, buf := range b.bufs { + size += len(buf) + } + return size +} + +// DumpTo outputs the contents of a buffer to a writer and resets the buffer. +func (b *Buffer) DumpTo(w io.Writer) (written int, err error) { + bufs := net.Buffers(b.bufs) + if len(b.Buf) > 0 { + bufs = append(bufs, b.Buf) + } + n, err := bufs.WriteTo(w) + + for _, buf := range b.bufs { + putBuf(buf) + } + putBuf(b.toPool) + + b.bufs = nil + b.Buf = nil + b.toPool = nil + + return int(n), err +} + +// BuildBytes creates a single byte slice with all the contents of the buffer. Data is +// copied if it does not fit in a single chunk. You can optionally provide one byte +// slice as argument that it will try to reuse. +func (b *Buffer) BuildBytes(reuse ...[]byte) []byte { + if len(b.bufs) == 0 { + ret := b.Buf + b.toPool = nil + b.Buf = nil + return ret + } + + var ret []byte + size := b.Size() + + // If we got a buffer as argument and it is big enough, reuse it. + if len(reuse) == 1 && cap(reuse[0]) >= size { + ret = reuse[0][:0] + } else { + ret = make([]byte, 0, size) + } + for _, buf := range b.bufs { + ret = append(ret, buf...) + putBuf(buf) + } + + ret = append(ret, b.Buf...) + putBuf(b.toPool) + + b.bufs = nil + b.toPool = nil + b.Buf = nil + + return ret +} + +type readCloser struct { + offset int + bufs [][]byte +} + +func (r *readCloser) Read(p []byte) (n int, err error) { + for _, buf := range r.bufs { + // Copy as much as we can. + x := copy(p[n:], buf[r.offset:]) + n += x // Increment how much we filled. + + // Did we empty the whole buffer? + if r.offset+x == len(buf) { + // On to the next buffer. + r.offset = 0 + r.bufs = r.bufs[1:] + + // We can release this buffer. + putBuf(buf) + } else { + r.offset += x + } + + if n == len(p) { + break + } + } + // No buffers left or nothing read? + if len(r.bufs) == 0 { + err = io.EOF + } + return +} + +func (r *readCloser) Close() error { + // Release all remaining buffers. + for _, buf := range r.bufs { + putBuf(buf) + } + // In case Close gets called multiple times. + r.bufs = nil + + return nil +} + +// ReadCloser creates an io.ReadCloser with all the contents of the buffer. +func (b *Buffer) ReadCloser() io.ReadCloser { + ret := &readCloser{0, append(b.bufs, b.Buf)} + + b.bufs = nil + b.toPool = nil + b.Buf = nil + + return ret +} diff --git a/vendor/github.com/mailru/easyjson/jwriter/writer.go b/vendor/github.com/mailru/easyjson/jwriter/writer.go new file mode 100644 index 0000000000..2c5b20105b --- /dev/null +++ b/vendor/github.com/mailru/easyjson/jwriter/writer.go @@ -0,0 +1,405 @@ +// Package jwriter contains a JSON writer. +package jwriter + +import ( + "io" + "strconv" + "unicode/utf8" + + "github.com/mailru/easyjson/buffer" +) + +// Flags describe various encoding options. The behavior may be actually implemented in the encoder, but +// Flags field in Writer is used to set and pass them around. +type Flags int + +const ( + NilMapAsEmpty Flags = 1 << iota // Encode nil map as '{}' rather than 'null'. + NilSliceAsEmpty // Encode nil slice as '[]' rather than 'null'. +) + +// Writer is a JSON writer. +type Writer struct { + Flags Flags + + Error error + Buffer buffer.Buffer + NoEscapeHTML bool +} + +// Size returns the size of the data that was written out. +func (w *Writer) Size() int { + return w.Buffer.Size() +} + +// DumpTo outputs the data to given io.Writer, resetting the buffer. +func (w *Writer) DumpTo(out io.Writer) (written int, err error) { + return w.Buffer.DumpTo(out) +} + +// BuildBytes returns writer data as a single byte slice. You can optionally provide one byte slice +// as argument that it will try to reuse. +func (w *Writer) BuildBytes(reuse ...[]byte) ([]byte, error) { + if w.Error != nil { + return nil, w.Error + } + + return w.Buffer.BuildBytes(reuse...), nil +} + +// ReadCloser returns an io.ReadCloser that can be used to read the data. +// ReadCloser also resets the buffer. +func (w *Writer) ReadCloser() (io.ReadCloser, error) { + if w.Error != nil { + return nil, w.Error + } + + return w.Buffer.ReadCloser(), nil +} + +// RawByte appends raw binary data to the buffer. +func (w *Writer) RawByte(c byte) { + w.Buffer.AppendByte(c) +} + +// RawByte appends raw binary data to the buffer. +func (w *Writer) RawString(s string) { + w.Buffer.AppendString(s) +} + +// Raw appends raw binary data to the buffer or sets the error if it is given. Useful for +// calling with results of MarshalJSON-like functions. +func (w *Writer) Raw(data []byte, err error) { + switch { + case w.Error != nil: + return + case err != nil: + w.Error = err + case len(data) > 0: + w.Buffer.AppendBytes(data) + default: + w.RawString("null") + } +} + +// RawText encloses raw binary data in quotes and appends in to the buffer. +// Useful for calling with results of MarshalText-like functions. +func (w *Writer) RawText(data []byte, err error) { + switch { + case w.Error != nil: + return + case err != nil: + w.Error = err + case len(data) > 0: + w.String(string(data)) + default: + w.RawString("null") + } +} + +// Base64Bytes appends data to the buffer after base64 encoding it +func (w *Writer) Base64Bytes(data []byte) { + if data == nil { + w.Buffer.AppendString("null") + return + } + w.Buffer.AppendByte('"') + w.base64(data) + w.Buffer.AppendByte('"') +} + +func (w *Writer) Uint8(n uint8) { + w.Buffer.EnsureSpace(3) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +func (w *Writer) Uint16(n uint16) { + w.Buffer.EnsureSpace(5) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +func (w *Writer) Uint32(n uint32) { + w.Buffer.EnsureSpace(10) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +func (w *Writer) Uint(n uint) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) +} + +func (w *Writer) Uint64(n uint64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, n, 10) +} + +func (w *Writer) Int8(n int8) { + w.Buffer.EnsureSpace(4) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +func (w *Writer) Int16(n int16) { + w.Buffer.EnsureSpace(6) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +func (w *Writer) Int32(n int32) { + w.Buffer.EnsureSpace(11) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +func (w *Writer) Int(n int) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) +} + +func (w *Writer) Int64(n int64) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, n, 10) +} + +func (w *Writer) Uint8Str(n uint8) { + w.Buffer.EnsureSpace(3) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Uint16Str(n uint16) { + w.Buffer.EnsureSpace(5) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Uint32Str(n uint32) { + w.Buffer.EnsureSpace(10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) UintStr(n uint) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Uint64Str(n uint64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, n, 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) UintptrStr(n uintptr) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendUint(w.Buffer.Buf, uint64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Int8Str(n int8) { + w.Buffer.EnsureSpace(4) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Int16Str(n int16) { + w.Buffer.EnsureSpace(6) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Int32Str(n int32) { + w.Buffer.EnsureSpace(11) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) IntStr(n int) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, int64(n), 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Int64Str(n int64) { + w.Buffer.EnsureSpace(21) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendInt(w.Buffer.Buf, n, 10) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Float32(n float32) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 32) +} + +func (w *Writer) Float32Str(n float32) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 32) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Float64(n float64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, n, 'g', -1, 64) +} + +func (w *Writer) Float64Str(n float64) { + w.Buffer.EnsureSpace(20) + w.Buffer.Buf = append(w.Buffer.Buf, '"') + w.Buffer.Buf = strconv.AppendFloat(w.Buffer.Buf, float64(n), 'g', -1, 64) + w.Buffer.Buf = append(w.Buffer.Buf, '"') +} + +func (w *Writer) Bool(v bool) { + w.Buffer.EnsureSpace(5) + if v { + w.Buffer.Buf = append(w.Buffer.Buf, "true"...) + } else { + w.Buffer.Buf = append(w.Buffer.Buf, "false"...) + } +} + +const chars = "0123456789abcdef" + +func getTable(falseValues ...int) [128]bool { + table := [128]bool{} + + for i := 0; i < 128; i++ { + table[i] = true + } + + for _, v := range falseValues { + table[v] = false + } + + return table +} + +var ( + htmlEscapeTable = getTable(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, '"', '&', '<', '>', '\\') + htmlNoEscapeTable = getTable(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, '"', '\\') +) + +func (w *Writer) String(s string) { + w.Buffer.AppendByte('"') + + // Portions of the string that contain no escapes are appended as + // byte slices. + + p := 0 // last non-escape symbol + + escapeTable := &htmlEscapeTable + if w.NoEscapeHTML { + escapeTable = &htmlNoEscapeTable + } + + for i := 0; i < len(s); { + c := s[i] + + if c < utf8.RuneSelf { + if escapeTable[c] { + // single-width character, no escaping is required + i++ + continue + } + + w.Buffer.AppendString(s[p:i]) + switch c { + case '\t': + w.Buffer.AppendString(`\t`) + case '\r': + w.Buffer.AppendString(`\r`) + case '\n': + w.Buffer.AppendString(`\n`) + case '\\': + w.Buffer.AppendString(`\\`) + case '"': + w.Buffer.AppendString(`\"`) + default: + w.Buffer.AppendString(`\u00`) + w.Buffer.AppendByte(chars[c>>4]) + w.Buffer.AppendByte(chars[c&0xf]) + } + + i++ + p = i + continue + } + + // broken utf + runeValue, runeWidth := utf8.DecodeRuneInString(s[i:]) + if runeValue == utf8.RuneError && runeWidth == 1 { + w.Buffer.AppendString(s[p:i]) + w.Buffer.AppendString(`\ufffd`) + i++ + p = i + continue + } + + // jsonp stuff - tab separator and line separator + if runeValue == '\u2028' || runeValue == '\u2029' { + w.Buffer.AppendString(s[p:i]) + w.Buffer.AppendString(`\u202`) + w.Buffer.AppendByte(chars[runeValue&0xf]) + i += runeWidth + p = i + continue + } + i += runeWidth + } + w.Buffer.AppendString(s[p:]) + w.Buffer.AppendByte('"') +} + +const encode = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/" +const padChar = '=' + +func (w *Writer) base64(in []byte) { + + if len(in) == 0 { + return + } + + w.Buffer.EnsureSpace(((len(in)-1)/3 + 1) * 4) + + si := 0 + n := (len(in) / 3) * 3 + + for si < n { + // Convert 3x 8bit source bytes into 4 bytes + val := uint(in[si+0])<<16 | uint(in[si+1])<<8 | uint(in[si+2]) + + w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>18&0x3F], encode[val>>12&0x3F], encode[val>>6&0x3F], encode[val&0x3F]) + + si += 3 + } + + remain := len(in) - si + if remain == 0 { + return + } + + // Add the remaining small block + val := uint(in[si+0]) << 16 + if remain == 2 { + val |= uint(in[si+1]) << 8 + } + + w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>18&0x3F], encode[val>>12&0x3F]) + + switch remain { + case 2: + w.Buffer.Buf = append(w.Buffer.Buf, encode[val>>6&0x3F], byte(padChar)) + case 1: + w.Buffer.Buf = append(w.Buffer.Buf, byte(padChar), byte(padChar)) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/LICENSE b/vendor/github.com/mark3labs/mcp-go/LICENSE new file mode 100644 index 0000000000..3d48435454 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Anthropic, PBC + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/consts.go b/vendor/github.com/mark3labs/mcp-go/mcp/consts.go new file mode 100644 index 0000000000..66eb3803bc --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/consts.go @@ -0,0 +1,9 @@ +package mcp + +const ( + ContentTypeText = "text" + ContentTypeImage = "image" + ContentTypeAudio = "audio" + ContentTypeLink = "resource_link" + ContentTypeResource = "resource" +) diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/errors.go b/vendor/github.com/mark3labs/mcp-go/mcp/errors.go new file mode 100644 index 0000000000..01888bf5b6 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/errors.go @@ -0,0 +1,25 @@ +package mcp + +import "fmt" + +// UnsupportedProtocolVersionError is returned when the server responds with +// a protocol version that the client doesn't support. +type UnsupportedProtocolVersionError struct { + Version string +} + +func (e UnsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version: %q", e.Version) +} + +// Is implements the errors.Is interface for better error handling +func (e UnsupportedProtocolVersionError) Is(target error) bool { + _, ok := target.(UnsupportedProtocolVersionError) + return ok +} + +// IsUnsupportedProtocolVersion checks if an error is an UnsupportedProtocolVersionError +func IsUnsupportedProtocolVersion(err error) bool { + _, ok := err.(UnsupportedProtocolVersionError) + return ok +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go new file mode 100644 index 0000000000..9b0b48ed23 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/prompts.go @@ -0,0 +1,176 @@ +package mcp + +import "net/http" + +/* Prompts */ + +// ListPromptsRequest is sent from the client to request a list of prompts and +// prompt templates the server has. +type ListPromptsRequest struct { + PaginatedRequest + Header http.Header `json:"-"` +} + +// ListPromptsResult is the server's response to a prompts/list request from +// the client. +type ListPromptsResult struct { + PaginatedResult + Prompts []Prompt `json:"prompts"` +} + +// GetPromptRequest is used by the client to get a prompt provided by the +// server. +type GetPromptRequest struct { + Request + Params GetPromptParams `json:"params"` + Header http.Header `json:"-"` +} + +type GetPromptParams struct { + // The name of the prompt or prompt template. + Name string `json:"name"` + // Arguments to use for templating the prompt. + Arguments map[string]string `json:"arguments,omitempty"` +} + +// GetPromptResult is the server's response to a prompts/get request from the +// client. +type GetPromptResult struct { + Result + // An optional description for the prompt. + Description string `json:"description,omitempty"` + Messages []PromptMessage `json:"messages"` +} + +// Prompt represents a prompt or prompt template that the server offers. +// If Arguments is non-nil and non-empty, this indicates the prompt is a template +// that requires argument values to be provided when calling prompts/get. +// If Arguments is nil or empty, this is a static prompt that takes no arguments. +type Prompt struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // The name of the prompt or prompt template. + Name string `json:"name"` + // An optional description of what this prompt provides + Description string `json:"description,omitempty"` + // A list of arguments to use for templating the prompt. + // The presence of arguments indicates this is a template prompt. + Arguments []PromptArgument `json:"arguments,omitempty"` +} + +// GetName returns the name of the prompt. +func (p Prompt) GetName() string { + return p.Name +} + +// PromptArgument describes an argument that a prompt template can accept. +// When a prompt includes arguments, clients must provide values for all +// required arguments when making a prompts/get request. +type PromptArgument struct { + // The name of the argument. + Name string `json:"name"` + // A human-readable description of the argument. + Description string `json:"description,omitempty"` + // Whether this argument must be provided. + // If true, clients must include this argument when calling prompts/get. + Required bool `json:"required,omitempty"` +} + +// Role represents the sender or recipient of messages and data in a +// conversation. +type Role string + +const ( + RoleUser Role = "user" + RoleAssistant Role = "assistant" +) + +// PromptMessage describes a message returned as part of a prompt. +// +// This is similar to `SamplingMessage`, but also supports the embedding of +// resources from the MCP server. +type PromptMessage struct { + Role Role `json:"role"` + Content Content `json:"content"` // Can be TextContent, ImageContent, AudioContent or EmbeddedResource +} + +// PromptListChangedNotification is an optional notification from the server +// to the client, informing it that the list of prompts it offers has changed. This +// may be issued by servers without any previous subscription from the client. +type PromptListChangedNotification struct { + Notification +} + +// PromptOption is a function that configures a Prompt. +// It provides a flexible way to set various properties of a Prompt using the functional options pattern. +type PromptOption func(*Prompt) + +// ArgumentOption is a function that configures a PromptArgument. +// It allows for flexible configuration of prompt arguments using the functional options pattern. +type ArgumentOption func(*PromptArgument) + +// +// Core Prompt Functions +// + +// NewPrompt creates a new Prompt with the given name and options. +// The prompt will be configured based on the provided options. +// Options are applied in order, allowing for flexible prompt configuration. +func NewPrompt(name string, opts ...PromptOption) Prompt { + prompt := Prompt{ + Name: name, + } + + for _, opt := range opts { + opt(&prompt) + } + + return prompt +} + +// WithPromptDescription adds a description to the Prompt. +// The description should provide a clear, human-readable explanation of what the prompt does. +func WithPromptDescription(description string) PromptOption { + return func(p *Prompt) { + p.Description = description + } +} + +// WithArgument adds an argument to the prompt's argument list. +// The argument will be configured based on the provided options. +func WithArgument(name string, opts ...ArgumentOption) PromptOption { + return func(p *Prompt) { + arg := PromptArgument{ + Name: name, + } + + for _, opt := range opts { + opt(&arg) + } + + if p.Arguments == nil { + p.Arguments = make([]PromptArgument, 0) + } + p.Arguments = append(p.Arguments, arg) + } +} + +// +// Argument Options +// + +// ArgumentDescription adds a description to a prompt argument. +// The description should explain the purpose and expected values of the argument. +func ArgumentDescription(desc string) ArgumentOption { + return func(arg *PromptArgument) { + arg.Description = desc + } +} + +// RequiredArgument marks an argument as required in the prompt. +// Required arguments must be provided when getting the prompt. +func RequiredArgument() ArgumentOption { + return func(arg *PromptArgument) { + arg.Required = true + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/resources.go b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go new file mode 100644 index 0000000000..07a59a3223 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/resources.go @@ -0,0 +1,99 @@ +package mcp + +import "github.com/yosida95/uritemplate/v3" + +// ResourceOption is a function that configures a Resource. +// It provides a flexible way to set various properties of a Resource using the functional options pattern. +type ResourceOption func(*Resource) + +// NewResource creates a new Resource with the given URI, name and options. +// The resource will be configured based on the provided options. +// Options are applied in order, allowing for flexible resource configuration. +func NewResource(uri string, name string, opts ...ResourceOption) Resource { + resource := Resource{ + URI: uri, + Name: name, + } + + for _, opt := range opts { + opt(&resource) + } + + return resource +} + +// WithResourceDescription adds a description to the Resource. +// The description should provide a clear, human-readable explanation of what the resource represents. +func WithResourceDescription(description string) ResourceOption { + return func(r *Resource) { + r.Description = description + } +} + +// WithMIMEType sets the MIME type for the Resource. +// This should indicate the format of the resource's contents. +func WithMIMEType(mimeType string) ResourceOption { + return func(r *Resource) { + r.MIMEType = mimeType + } +} + +// WithAnnotations adds annotations to the Resource. +// Annotations can provide additional metadata about the resource's intended use. +func WithAnnotations(audience []Role, priority float64) ResourceOption { + return func(r *Resource) { + if r.Annotations == nil { + r.Annotations = &Annotations{} + } + r.Annotations.Audience = audience + r.Annotations.Priority = priority + } +} + +// ResourceTemplateOption is a function that configures a ResourceTemplate. +// It provides a flexible way to set various properties of a ResourceTemplate using the functional options pattern. +type ResourceTemplateOption func(*ResourceTemplate) + +// NewResourceTemplate creates a new ResourceTemplate with the given URI template, name and options. +// The template will be configured based on the provided options. +// Options are applied in order, allowing for flexible template configuration. +func NewResourceTemplate(uriTemplate string, name string, opts ...ResourceTemplateOption) ResourceTemplate { + template := ResourceTemplate{ + URITemplate: &URITemplate{Template: uritemplate.MustNew(uriTemplate)}, + Name: name, + } + + for _, opt := range opts { + opt(&template) + } + + return template +} + +// WithTemplateDescription adds a description to the ResourceTemplate. +// The description should provide a clear, human-readable explanation of what resources this template represents. +func WithTemplateDescription(description string) ResourceTemplateOption { + return func(t *ResourceTemplate) { + t.Description = description + } +} + +// WithTemplateMIMEType sets the MIME type for the ResourceTemplate. +// This should only be set if all resources matching this template will have the same type. +func WithTemplateMIMEType(mimeType string) ResourceTemplateOption { + return func(t *ResourceTemplate) { + t.MIMEType = mimeType + } +} + +// WithTemplateAnnotations adds annotations to the ResourceTemplate. +// Annotations can provide additional metadata about the template's intended use. +func WithTemplateAnnotations(audience []Role, priority float64) ResourceTemplateOption { + return func(t *ResourceTemplate) { + if t.Annotations == nil { + t.Annotations = &Annotations{} + } + t.Annotations.Audience = audience + t.Annotations.Priority = priority + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go new file mode 100644 index 0000000000..493e8c7787 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/tools.go @@ -0,0 +1,1277 @@ +package mcp + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "reflect" + "strconv" + + "github.com/invopop/jsonschema" +) + +var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both") + +// ListToolsRequest is sent from the client to request a list of tools the +// server has. +type ListToolsRequest struct { + PaginatedRequest + Header http.Header `json:"-"` +} + +// ListToolsResult is the server's response to a tools/list request from the +// client. +type ListToolsResult struct { + PaginatedResult + Tools []Tool `json:"tools"` +} + +// CallToolResult is the server's response to a tool call. +// +// Any errors that originate from the tool SHOULD be reported inside the result +// object, with `isError` set to true, _not_ as an MCP protocol-level error +// response. Otherwise, the LLM would not be able to see that an error occurred +// and self-correct. +// +// However, any errors in _finding_ the tool, an error indicating that the +// server does not support tool calls, or any other exceptional conditions, +// should be reported as an MCP error response. +type CallToolResult struct { + Result + Content []Content `json:"content"` // Can be TextContent, ImageContent, AudioContent, or EmbeddedResource + // Structured content returned as a JSON object in the structuredContent field of a result. + // For backwards compatibility, a tool that returns structured content SHOULD also return + // functionally equivalent unstructured content. + StructuredContent any `json:"structuredContent,omitempty"` + // Whether the tool call ended in an error. + // + // If not set, this is assumed to be false (the call was successful). + IsError bool `json:"isError,omitempty"` +} + +// CallToolRequest is used by the client to invoke a tool provided by the server. +type CallToolRequest struct { + Request + Header http.Header `json:"-"` // HTTP headers from the original request + Params CallToolParams `json:"params"` +} + +type CallToolParams struct { + Name string `json:"name"` + Arguments any `json:"arguments,omitempty"` + Meta *Meta `json:"_meta,omitempty"` +} + +// GetArguments returns the Arguments as map[string]any for backward compatibility +// If Arguments is not a map, it returns an empty map +func (r CallToolRequest) GetArguments() map[string]any { + if args, ok := r.Params.Arguments.(map[string]any); ok { + return args + } + return nil +} + +// GetRawArguments returns the Arguments as-is without type conversion +// This allows users to access the raw arguments in any format +func (r CallToolRequest) GetRawArguments() any { + return r.Params.Arguments +} + +// BindArguments unmarshals the Arguments into the provided struct +// This is useful for working with strongly-typed arguments +func (r CallToolRequest) BindArguments(target any) error { + if target == nil || reflect.ValueOf(target).Kind() != reflect.Ptr { + return fmt.Errorf("target must be a non-nil pointer") + } + + // Fast-path: already raw JSON + if raw, ok := r.Params.Arguments.(json.RawMessage); ok { + return json.Unmarshal(raw, target) + } + + data, err := json.Marshal(r.Params.Arguments) + if err != nil { + return fmt.Errorf("failed to marshal arguments: %w", err) + } + + return json.Unmarshal(data, target) +} + +// GetString returns a string argument by key, or the default value if not found +func (r CallToolRequest) GetString(key string, defaultValue string) string { + args := r.GetArguments() + if val, ok := args[key]; ok { + if str, ok := val.(string); ok { + return str + } + } + return defaultValue +} + +// RequireString returns a string argument by key, or an error if not found or not a string +func (r CallToolRequest) RequireString(key string) (string, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + if str, ok := val.(string); ok { + return str, nil + } + return "", fmt.Errorf("argument %q is not a string", key) + } + return "", fmt.Errorf("required argument %q not found", key) +} + +// GetInt returns an int argument by key, or the default value if not found +func (r CallToolRequest) GetInt(key string, defaultValue int) int { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case int: + return v + case float64: + return int(v) + case string: + if i, err := strconv.Atoi(v); err == nil { + return i + } + } + } + return defaultValue +} + +// RequireInt returns an int argument by key, or an error if not found or not convertible to int +func (r CallToolRequest) RequireInt(key string) (int, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case int: + return v, nil + case float64: + return int(v), nil + case string: + if i, err := strconv.Atoi(v); err == nil { + return i, nil + } + return 0, fmt.Errorf("argument %q cannot be converted to int", key) + default: + return 0, fmt.Errorf("argument %q is not an int", key) + } + } + return 0, fmt.Errorf("required argument %q not found", key) +} + +// GetFloat returns a float64 argument by key, or the default value if not found +func (r CallToolRequest) GetFloat(key string, defaultValue float64) float64 { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case float64: + return v + case int: + return float64(v) + case string: + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f + } + } + } + return defaultValue +} + +// RequireFloat returns a float64 argument by key, or an error if not found or not convertible to float64 +func (r CallToolRequest) RequireFloat(key string) (float64, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case float64: + return v, nil + case int: + return float64(v), nil + case string: + if f, err := strconv.ParseFloat(v, 64); err == nil { + return f, nil + } + return 0, fmt.Errorf("argument %q cannot be converted to float64", key) + default: + return 0, fmt.Errorf("argument %q is not a float64", key) + } + } + return 0, fmt.Errorf("required argument %q not found", key) +} + +// GetBool returns a bool argument by key, or the default value if not found +func (r CallToolRequest) GetBool(key string, defaultValue bool) bool { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case bool: + return v + case string: + if b, err := strconv.ParseBool(v); err == nil { + return b + } + case int: + return v != 0 + case float64: + return v != 0 + } + } + return defaultValue +} + +// RequireBool returns a bool argument by key, or an error if not found or not convertible to bool +func (r CallToolRequest) RequireBool(key string) (bool, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case bool: + return v, nil + case string: + if b, err := strconv.ParseBool(v); err == nil { + return b, nil + } + return false, fmt.Errorf("argument %q cannot be converted to bool", key) + case int: + return v != 0, nil + case float64: + return v != 0, nil + default: + return false, fmt.Errorf("argument %q is not a bool", key) + } + } + return false, fmt.Errorf("required argument %q not found", key) +} + +// GetStringSlice returns a string slice argument by key, or the default value if not found +func (r CallToolRequest) GetStringSlice(key string, defaultValue []string) []string { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []string: + return v + case []any: + result := make([]string, 0, len(v)) + for _, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } + } + return result + } + } + return defaultValue +} + +// RequireStringSlice returns a string slice argument by key, or an error if not found or not convertible to string slice +func (r CallToolRequest) RequireStringSlice(key string) ([]string, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []string: + return v, nil + case []any: + result := make([]string, 0, len(v)) + for i, item := range v { + if str, ok := item.(string); ok { + result = append(result, str) + } else { + return nil, fmt.Errorf("item %d in argument %q is not a string", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a string slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetIntSlice returns an int slice argument by key, or the default value if not found +func (r CallToolRequest) GetIntSlice(key string, defaultValue []int) []int { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []int: + return v + case []any: + result := make([]int, 0, len(v)) + for _, item := range v { + switch num := item.(type) { + case int: + result = append(result, num) + case float64: + result = append(result, int(num)) + case string: + if i, err := strconv.Atoi(num); err == nil { + result = append(result, i) + } + } + } + return result + } + } + return defaultValue +} + +// RequireIntSlice returns an int slice argument by key, or an error if not found or not convertible to int slice +func (r CallToolRequest) RequireIntSlice(key string) ([]int, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []int: + return v, nil + case []any: + result := make([]int, 0, len(v)) + for i, item := range v { + switch num := item.(type) { + case int: + result = append(result, num) + case float64: + result = append(result, int(num)) + case string: + if i, err := strconv.Atoi(num); err == nil { + result = append(result, i) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to int", i, key) + } + default: + return nil, fmt.Errorf("item %d in argument %q is not an int", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not an int slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetFloatSlice returns a float64 slice argument by key, or the default value if not found +func (r CallToolRequest) GetFloatSlice(key string, defaultValue []float64) []float64 { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []float64: + return v + case []any: + result := make([]float64, 0, len(v)) + for _, item := range v { + switch num := item.(type) { + case float64: + result = append(result, num) + case int: + result = append(result, float64(num)) + case string: + if f, err := strconv.ParseFloat(num, 64); err == nil { + result = append(result, f) + } + } + } + return result + } + } + return defaultValue +} + +// RequireFloatSlice returns a float64 slice argument by key, or an error if not found or not convertible to float64 slice +func (r CallToolRequest) RequireFloatSlice(key string) ([]float64, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []float64: + return v, nil + case []any: + result := make([]float64, 0, len(v)) + for i, item := range v { + switch num := item.(type) { + case float64: + result = append(result, num) + case int: + result = append(result, float64(num)) + case string: + if f, err := strconv.ParseFloat(num, 64); err == nil { + result = append(result, f) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to float64", i, key) + } + default: + return nil, fmt.Errorf("item %d in argument %q is not a float64", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a float64 slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// GetBoolSlice returns a bool slice argument by key, or the default value if not found +func (r CallToolRequest) GetBoolSlice(key string, defaultValue []bool) []bool { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []bool: + return v + case []any: + result := make([]bool, 0, len(v)) + for _, item := range v { + switch b := item.(type) { + case bool: + result = append(result, b) + case string: + if parsed, err := strconv.ParseBool(b); err == nil { + result = append(result, parsed) + } + case int: + result = append(result, b != 0) + case float64: + result = append(result, b != 0) + } + } + return result + } + } + return defaultValue +} + +// RequireBoolSlice returns a bool slice argument by key, or an error if not found or not convertible to bool slice +func (r CallToolRequest) RequireBoolSlice(key string) ([]bool, error) { + args := r.GetArguments() + if val, ok := args[key]; ok { + switch v := val.(type) { + case []bool: + return v, nil + case []any: + result := make([]bool, 0, len(v)) + for i, item := range v { + switch b := item.(type) { + case bool: + result = append(result, b) + case string: + if parsed, err := strconv.ParseBool(b); err == nil { + result = append(result, parsed) + } else { + return nil, fmt.Errorf("item %d in argument %q cannot be converted to bool", i, key) + } + case int: + result = append(result, b != 0) + case float64: + result = append(result, b != 0) + default: + return nil, fmt.Errorf("item %d in argument %q is not a bool", i, key) + } + } + return result, nil + default: + return nil, fmt.Errorf("argument %q is not a bool slice", key) + } + } + return nil, fmt.Errorf("required argument %q not found", key) +} + +// MarshalJSON implements custom JSON marshaling for CallToolResult +func (r CallToolResult) MarshalJSON() ([]byte, error) { + m := make(map[string]any) + + // Marshal Meta if present + if r.Meta != nil { + m["_meta"] = r.Meta + } + + // Marshal Content array + content := make([]any, len(r.Content)) + for i, c := range r.Content { + content[i] = c + } + m["content"] = content + + // Marshal StructuredContent if present + if r.StructuredContent != nil { + m["structuredContent"] = r.StructuredContent + } + + // Marshal IsError if true + if r.IsError { + m["isError"] = r.IsError + } + + return json.Marshal(m) +} + +// UnmarshalJSON implements custom JSON unmarshaling for CallToolResult +func (r *CallToolResult) UnmarshalJSON(data []byte) error { + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Unmarshal Meta + if meta, ok := raw["_meta"]; ok { + if metaMap, ok := meta.(map[string]any); ok { + r.Meta = NewMetaFromMap(metaMap) + } + } + + // Unmarshal Content array + if contentRaw, ok := raw["content"]; ok { + if contentArray, ok := contentRaw.([]any); ok { + r.Content = make([]Content, len(contentArray)) + for i, item := range contentArray { + itemBytes, err := json.Marshal(item) + if err != nil { + return err + } + content, err := UnmarshalContent(itemBytes) + if err != nil { + return err + } + r.Content[i] = content + } + } + } + + // Unmarshal StructuredContent if present + if structured, ok := raw["structuredContent"]; ok { + r.StructuredContent = structured + } + + // Unmarshal IsError + if isError, ok := raw["isError"]; ok { + if isErrorBool, ok := isError.(bool); ok { + r.IsError = isErrorBool + } + } + + return nil +} + +// ToolListChangedNotification is an optional notification from the server to +// the client, informing it that the list of tools it offers has changed. This may +// be issued by servers without any previous subscription from the client. +type ToolListChangedNotification struct { + Notification +} + +// Tool represents the definition for a tool the client can call. +type Tool struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // The name of the tool. + Name string `json:"name"` + // A human-readable description of the tool. + Description string `json:"description,omitempty"` + // A JSON Schema object defining the expected parameters for the tool. + InputSchema ToolInputSchema `json:"inputSchema"` + // Alternative to InputSchema - allows arbitrary JSON Schema to be provided + RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling + // A JSON Schema object defining the expected output returned by the tool . + OutputSchema ToolOutputSchema `json:"outputSchema,omitempty"` + // Optional JSON Schema defining expected output structure + RawOutputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling + // Optional properties describing tool behavior + Annotations ToolAnnotation `json:"annotations"` +} + +// GetName returns the name of the tool. +func (t Tool) GetName() string { + return t.Name +} + +// MarshalJSON implements the json.Marshaler interface for Tool. +// It handles marshaling either InputSchema or RawInputSchema based on which is set. +func (t Tool) MarshalJSON() ([]byte, error) { + // Create a map to build the JSON structure + m := make(map[string]any, 5) + + // Add the name and description + m["name"] = t.Name + if t.Description != "" { + m["description"] = t.Description + } + + // Determine which input schema to use + if t.RawInputSchema != nil { + if t.InputSchema.Type != "" { + return nil, fmt.Errorf("tool %s has both InputSchema and RawInputSchema set: %w", t.Name, errToolSchemaConflict) + } + m["inputSchema"] = t.RawInputSchema + } else { + // Use the structured InputSchema + m["inputSchema"] = t.InputSchema + } + + // Add output schema if present + if t.RawOutputSchema != nil { + if t.OutputSchema.Type != "" { + return nil, fmt.Errorf("tool %s has both OutputSchema and RawOutputSchema set: %w", t.Name, errToolSchemaConflict) + } + m["outputSchema"] = t.RawOutputSchema + } else if t.OutputSchema.Type != "" { // If no output schema is specified, do not return anything + m["outputSchema"] = t.OutputSchema + } + + m["annotations"] = t.Annotations + + return json.Marshal(m) +} + +// ToolArgumentsSchema represents a JSON Schema for tool arguments. +type ToolArgumentsSchema struct { + Defs map[string]any `json:"$defs,omitempty"` + Type string `json:"type"` + Properties map[string]any `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type ToolInputSchema ToolArgumentsSchema // For retro-compatibility +type ToolOutputSchema ToolArgumentsSchema + +// MarshalJSON implements the json.Marshaler interface for ToolInputSchema. +func (tis ToolArgumentsSchema) MarshalJSON() ([]byte, error) { + m := make(map[string]any) + m["type"] = tis.Type + + if tis.Defs != nil { + m["$defs"] = tis.Defs + } + + // Marshal Properties to '{}' rather than `nil` when its length equals zero + if tis.Properties != nil { + m["properties"] = tis.Properties + } + + if len(tis.Required) > 0 { + m["required"] = tis.Required + } + + return json.Marshal(m) +} + +type ToolAnnotation struct { + // Human-readable title for the tool + Title string `json:"title,omitempty"` + // If true, the tool does not modify its environment + ReadOnlyHint *bool `json:"readOnlyHint,omitempty"` + // If true, the tool may perform destructive updates + DestructiveHint *bool `json:"destructiveHint,omitempty"` + // If true, repeated calls with same args have no additional effect + IdempotentHint *bool `json:"idempotentHint,omitempty"` + // If true, tool interacts with external entities + OpenWorldHint *bool `json:"openWorldHint,omitempty"` +} + +// ToolOption is a function that configures a Tool. +// It provides a flexible way to set various properties of a Tool using the functional options pattern. +type ToolOption func(*Tool) + +// PropertyOption is a function that configures a property in a Tool's input schema. +// It allows for flexible configuration of JSON Schema properties using the functional options pattern. +type PropertyOption func(map[string]any) + +// +// Core Tool Functions +// + +// NewTool creates a new Tool with the given name and options. +// The tool will have an object-type input schema with configurable properties. +// Options are applied in order, allowing for flexible tool configuration. +func NewTool(name string, opts ...ToolOption) Tool { + tool := Tool{ + Name: name, + InputSchema: ToolInputSchema{ + Type: "object", + Properties: make(map[string]any), + Required: nil, // Will be omitted from JSON if empty + }, + Annotations: ToolAnnotation{ + Title: "", + ReadOnlyHint: ToBoolPtr(false), + DestructiveHint: ToBoolPtr(true), + IdempotentHint: ToBoolPtr(false), + OpenWorldHint: ToBoolPtr(true), + }, + } + + for _, opt := range opts { + opt(&tool) + } + + return tool +} + +// NewToolWithRawSchema creates a new Tool with the given name and a raw JSON +// Schema. This allows for arbitrary JSON Schema to be used for the tool's input +// schema. +// +// NOTE a [Tool] built in such a way is incompatible with the [ToolOption] and +// runtime errors will result from supplying a [ToolOption] to a [Tool] built +// with this function. +func NewToolWithRawSchema(name, description string, schema json.RawMessage) Tool { + tool := Tool{ + Name: name, + Description: description, + RawInputSchema: schema, + } + + return tool +} + +// WithDescription adds a description to the Tool. +// The description should provide a clear, human-readable explanation of what the tool does. +func WithDescription(description string) ToolOption { + return func(t *Tool) { + t.Description = description + } +} + +// WithInputSchema creates a ToolOption that sets the input schema for a tool. +// It accepts any Go type, usually a struct, and automatically generates a JSON schema from it. +func WithInputSchema[T any]() ToolOption { + return func(t *Tool) { + var zero T + + // Generate schema using invopop/jsonschema library + // Configure reflector to generate clean, MCP-compatible schemas + reflector := jsonschema.Reflector{ + DoNotReference: true, // Removes $defs map, outputs entire structure inline + Anonymous: true, // Hides auto-generated Schema IDs + AllowAdditionalProperties: true, // Removes additionalProperties: false + } + schema := reflector.Reflect(zero) + + // Clean up schema for MCP compliance + schema.Version = "" // Remove $schema field + + // Convert to raw JSON for MCP + mcpSchema, err := json.Marshal(schema) + if err != nil { + // Skip and maintain backward compatibility + return + } + + t.InputSchema.Type = "" + t.RawInputSchema = json.RawMessage(mcpSchema) + } +} + +// WithRawInputSchema sets a raw JSON schema for the tool's input. +// Use this when you need full control over the schema or when working with +// complex schemas that can't be generated from Go types. The jsonschema library +// can handle complex schemas and provides nice extension points, so be sure to +// check that out before using this. +func WithRawInputSchema(schema json.RawMessage) ToolOption { + return func(t *Tool) { + t.RawInputSchema = schema + } +} + +// WithOutputSchema creates a ToolOption that sets the output schema for a tool. +// It accepts any Go type, usually a struct, and automatically generates a JSON schema from it. +func WithOutputSchema[T any]() ToolOption { + return func(t *Tool) { + var zero T + + // Generate schema using invopop/jsonschema library + // Configure reflector to generate clean, MCP-compatible schemas + reflector := jsonschema.Reflector{ + DoNotReference: true, // Removes $defs map, outputs entire structure inline + Anonymous: true, // Hides auto-generated Schema IDs + AllowAdditionalProperties: true, // Removes additionalProperties: false + } + schema := reflector.Reflect(zero) + + // Clean up schema for MCP compliance + schema.Version = "" // Remove $schema field + + // Convert to raw JSON for MCP + mcpSchema, err := json.Marshal(schema) + if err != nil { + // Skip and maintain backward compatibility + return + } + + // Retrieve the schema from raw JSON + if err := json.Unmarshal(mcpSchema, &t.OutputSchema); err != nil { + // Skip and maintain backward compatibility + return + } + + // Always set the type to "object" as of the current MCP spec + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#output-schema + t.OutputSchema.Type = "object" + } +} + +// WithRawOutputSchema sets a raw JSON schema for the tool's output. +// Use this when you need full control over the schema or when working with +// complex schemas that can't be generated from Go types. The jsonschema library +// can handle complex schemas and provides nice extension points, so be sure to +// check that out before using this. +func WithRawOutputSchema(schema json.RawMessage) ToolOption { + return func(t *Tool) { + t.RawOutputSchema = schema + } +} + +// WithToolAnnotation adds optional hints about the Tool. +func WithToolAnnotation(annotation ToolAnnotation) ToolOption { + return func(t *Tool) { + t.Annotations = annotation + } +} + +// WithTitleAnnotation sets the Title field of the Tool's Annotations. +// It provides a human-readable title for the tool. +func WithTitleAnnotation(title string) ToolOption { + return func(t *Tool) { + t.Annotations.Title = title + } +} + +// WithReadOnlyHintAnnotation sets the ReadOnlyHint field of the Tool's Annotations. +// If true, it indicates the tool does not modify its environment. +func WithReadOnlyHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.ReadOnlyHint = &value + } +} + +// WithDestructiveHintAnnotation sets the DestructiveHint field of the Tool's Annotations. +// If true, it indicates the tool may perform destructive updates. +func WithDestructiveHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.DestructiveHint = &value + } +} + +// WithIdempotentHintAnnotation sets the IdempotentHint field of the Tool's Annotations. +// If true, it indicates repeated calls with the same arguments have no additional effect. +func WithIdempotentHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.IdempotentHint = &value + } +} + +// WithOpenWorldHintAnnotation sets the OpenWorldHint field of the Tool's Annotations. +// If true, it indicates the tool interacts with external entities. +func WithOpenWorldHintAnnotation(value bool) ToolOption { + return func(t *Tool) { + t.Annotations.OpenWorldHint = &value + } +} + +// +// Common Property Options +// + +// Description adds a description to a property in the JSON Schema. +// The description should explain the purpose and expected values of the property. +func Description(desc string) PropertyOption { + return func(schema map[string]any) { + schema["description"] = desc + } +} + +// Required marks a property as required in the tool's input schema. +// Required properties must be provided when using the tool. +func Required() PropertyOption { + return func(schema map[string]any) { + schema["required"] = true + } +} + +// Title adds a display-friendly title to a property in the JSON Schema. +// This title can be used by UI components to show a more readable property name. +func Title(title string) PropertyOption { + return func(schema map[string]any) { + schema["title"] = title + } +} + +// +// String Property Options +// + +// DefaultString sets the default value for a string property. +// This value will be used if the property is not explicitly provided. +func DefaultString(value string) PropertyOption { + return func(schema map[string]any) { + schema["default"] = value + } +} + +// Enum specifies a list of allowed values for a string property. +// The property value must be one of the specified enum values. +func Enum(values ...string) PropertyOption { + return func(schema map[string]any) { + schema["enum"] = values + } +} + +// MaxLength sets the maximum length for a string property. +// The string value must not exceed this length. +func MaxLength(max int) PropertyOption { + return func(schema map[string]any) { + schema["maxLength"] = max + } +} + +// MinLength sets the minimum length for a string property. +// The string value must be at least this length. +func MinLength(min int) PropertyOption { + return func(schema map[string]any) { + schema["minLength"] = min + } +} + +// Pattern sets a regex pattern that a string property must match. +// The string value must conform to the specified regular expression. +func Pattern(pattern string) PropertyOption { + return func(schema map[string]any) { + schema["pattern"] = pattern + } +} + +// +// Number Property Options +// + +// DefaultNumber sets the default value for a number property. +// This value will be used if the property is not explicitly provided. +func DefaultNumber(value float64) PropertyOption { + return func(schema map[string]any) { + schema["default"] = value + } +} + +// Max sets the maximum value for a number property. +// The number value must not exceed this maximum. +func Max(max float64) PropertyOption { + return func(schema map[string]any) { + schema["maximum"] = max + } +} + +// Min sets the minimum value for a number property. +// The number value must not be less than this minimum. +func Min(min float64) PropertyOption { + return func(schema map[string]any) { + schema["minimum"] = min + } +} + +// MultipleOf specifies that a number must be a multiple of the given value. +// The number value must be divisible by this value. +func MultipleOf(value float64) PropertyOption { + return func(schema map[string]any) { + schema["multipleOf"] = value + } +} + +// +// Boolean Property Options +// + +// DefaultBool sets the default value for a boolean property. +// This value will be used if the property is not explicitly provided. +func DefaultBool(value bool) PropertyOption { + return func(schema map[string]any) { + schema["default"] = value + } +} + +// +// Array Property Options +// + +// DefaultArray sets the default value for an array property. +// This value will be used if the property is not explicitly provided. +func DefaultArray[T any](value []T) PropertyOption { + return func(schema map[string]any) { + schema["default"] = value + } +} + +// +// Property Type Helpers +// + +// WithBoolean adds a boolean property to the tool schema. +// It accepts property options to configure the boolean property's behavior and constraints. +func WithBoolean(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithNumber adds a number property to the tool schema. +// It accepts property options to configure the number property's behavior and constraints. +func WithNumber(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithString adds a string property to the tool schema. +// It accepts property options to configure the string property's behavior and constraints. +func WithString(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithObject adds an object property to the tool schema. +// It accepts property options to configure the object property's behavior and constraints. +func WithObject(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "object", + "properties": map[string]any{}, + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// WithArray adds an array property to the tool schema. +// It accepts property options to configure the array property's behavior and constraints. +func WithArray(name string, opts ...PropertyOption) ToolOption { + return func(t *Tool) { + schema := map[string]any{ + "type": "array", + } + + for _, opt := range opts { + opt(schema) + } + + // Remove required from property schema and add to InputSchema.required + if required, ok := schema["required"].(bool); ok && required { + delete(schema, "required") + t.InputSchema.Required = append(t.InputSchema.Required, name) + } + + t.InputSchema.Properties[name] = schema + } +} + +// Properties defines the properties for an object schema +func Properties(props map[string]any) PropertyOption { + return func(schema map[string]any) { + schema["properties"] = props + } +} + +// AdditionalProperties specifies whether additional properties are allowed in the object +// or defines a schema for additional properties +func AdditionalProperties(schema any) PropertyOption { + return func(schemaMap map[string]any) { + schemaMap["additionalProperties"] = schema + } +} + +// MinProperties sets the minimum number of properties for an object +func MinProperties(min int) PropertyOption { + return func(schema map[string]any) { + schema["minProperties"] = min + } +} + +// MaxProperties sets the maximum number of properties for an object +func MaxProperties(max int) PropertyOption { + return func(schema map[string]any) { + schema["maxProperties"] = max + } +} + +// PropertyNames defines a schema for property names in an object +func PropertyNames(schema map[string]any) PropertyOption { + return func(schemaMap map[string]any) { + schemaMap["propertyNames"] = schema + } +} + +// Items defines the schema for array items. +// Accepts any schema definition for maximum flexibility. +// +// Example: +// +// Items(map[string]any{ +// "type": "object", +// "properties": map[string]any{ +// "name": map[string]any{"type": "string"}, +// "age": map[string]any{"type": "number"}, +// }, +// }) +// +// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead. +func Items(schema any) PropertyOption { + return func(schemaMap map[string]any) { + schemaMap["items"] = schema + } +} + +// MinItems sets the minimum number of items for an array +func MinItems(min int) PropertyOption { + return func(schema map[string]any) { + schema["minItems"] = min + } +} + +// MaxItems sets the maximum number of items for an array +func MaxItems(max int) PropertyOption { + return func(schema map[string]any) { + schema["maxItems"] = max + } +} + +// UniqueItems specifies whether array items must be unique +func UniqueItems(unique bool) PropertyOption { + return func(schema map[string]any) { + schema["uniqueItems"] = unique + } +} + +// WithStringItems configures an array's items to be of type string. +// +// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("tags", mcp.WithStringItems()) +// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue"))) +// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50))) +// +// Limitations: Only supports simple string arrays. Use Items() for complex objects. +func WithStringItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithStringEnumItems configures an array's items to be of type string with a specified enum. +// Example: +// +// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"})) +// +// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility. +func WithStringEnumItems(values []string) PropertyOption { + return func(schema map[string]any) { + schema["items"] = map[string]any{ + "type": "string", + "enum": values, + } + } +} + +// WithNumberItems configures an array's items to be of type number. +// +// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100))) +// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0))) +// +// Limitations: Only supports simple number arrays. Use Items() for complex objects. +func WithNumberItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithBooleanItems configures an array's items to be of type boolean. +// +// Supported options: Description(), DefaultBool() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("flags", mcp.WithBooleanItems()) +// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions"))) +// +// Limitations: Only supports simple boolean arrays. Use Items() for complex objects. +func WithBooleanItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go b/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go new file mode 100644 index 0000000000..a03a19dd79 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/typed_tools.go @@ -0,0 +1,42 @@ +package mcp + +import ( + "context" + "fmt" +) + +// TypedToolHandlerFunc is a function that handles a tool call with typed arguments +type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error) + +// StructuredToolHandlerFunc is a function that handles a tool call with typed arguments and returns structured output +type StructuredToolHandlerFunc[TArgs any, TResult any] func(ctx context.Context, request CallToolRequest, args TArgs) (TResult, error) + +// NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct +func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + var args T + if err := request.BindArguments(&args); err != nil { + return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil + } + return handler(ctx, request, args) + } +} + +// NewStructuredToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct +// and returns structured output. It automatically creates both structured and +// text content (from the structured output) for backwards compatibility. +func NewStructuredToolHandler[TArgs any, TResult any](handler StructuredToolHandlerFunc[TArgs, TResult]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + var args TArgs + if err := request.BindArguments(&args); err != nil { + return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil + } + + result, err := handler(ctx, request, args) + if err != nil { + return NewToolResultError(fmt.Sprintf("tool execution failed: %v", err)), nil + } + + return NewToolResultStructuredOnly(result), nil + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/types.go b/vendor/github.com/mark3labs/mcp-go/mcp/types.go new file mode 100644 index 0000000000..f871b7d9d5 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/types.go @@ -0,0 +1,1173 @@ +// Package mcp defines the core types and interfaces for the Model Context Protocol (MCP). +// MCP is a protocol for communication between LLM-powered applications and their supporting services. +package mcp + +import ( + "encoding/json" + "fmt" + "maps" + "strconv" + + "net/http" + + "github.com/yosida95/uritemplate/v3" +) + +type MCPMethod string + +const ( + // MethodInitialize initiates connection and negotiates protocol capabilities. + // https://modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization + MethodInitialize MCPMethod = "initialize" + + // MethodPing verifies connection liveness between client and server. + // https://modelcontextprotocol.io/specification/2024-11-05/basic/utilities/ping/ + MethodPing MCPMethod = "ping" + + // MethodResourcesList lists all available server resources. + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesList MCPMethod = "resources/list" + + // MethodResourcesTemplatesList provides URI templates for constructing resource URIs. + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesTemplatesList MCPMethod = "resources/templates/list" + + // MethodResourcesRead retrieves content of a specific resource by URI. + // https://modelcontextprotocol.io/specification/2024-11-05/server/resources/ + MethodResourcesRead MCPMethod = "resources/read" + + // MethodPromptsList lists all available prompt templates. + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + MethodPromptsList MCPMethod = "prompts/list" + + // MethodPromptsGet retrieves a specific prompt template with filled parameters. + // https://modelcontextprotocol.io/specification/2024-11-05/server/prompts/ + MethodPromptsGet MCPMethod = "prompts/get" + + // MethodToolsList lists all available executable tools. + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ + MethodToolsList MCPMethod = "tools/list" + + // MethodToolsCall invokes a specific tool with provided parameters. + // https://modelcontextprotocol.io/specification/2024-11-05/server/tools/ + MethodToolsCall MCPMethod = "tools/call" + + // MethodSetLogLevel configures the minimum log level for client + // https://modelcontextprotocol.io/specification/2025-03-26/server/utilities/logging + MethodSetLogLevel MCPMethod = "logging/setLevel" + + // MethodNotificationResourcesListChanged notifies when the list of available resources changes. + // https://modelcontextprotocol.io/specification/2025-03-26/server/resources#list-changed-notification + MethodNotificationResourcesListChanged = "notifications/resources/list_changed" + + MethodNotificationResourceUpdated = "notifications/resources/updated" + + // MethodNotificationPromptsListChanged notifies when the list of available prompt templates changes. + // https://modelcontextprotocol.io/specification/2025-03-26/server/prompts#list-changed-notification + MethodNotificationPromptsListChanged = "notifications/prompts/list_changed" + + // MethodNotificationToolsListChanged notifies when the list of available tools changes. + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/list_changed/ + MethodNotificationToolsListChanged = "notifications/tools/list_changed" +) + +type URITemplate struct { + *uritemplate.Template +} + +func (t *URITemplate) MarshalJSON() ([]byte, error) { + return json.Marshal(t.Raw()) +} + +func (t *URITemplate) UnmarshalJSON(data []byte) error { + var raw string + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + template, err := uritemplate.New(raw) + if err != nil { + return err + } + t.Template = template + return nil +} + +/* JSON-RPC types */ + +// JSONRPCMessage represents either a JSONRPCRequest, JSONRPCNotification, JSONRPCResponse, or JSONRPCError +type JSONRPCMessage any + +// LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. +const LATEST_PROTOCOL_VERSION = "2025-06-18" + +// ValidProtocolVersions lists all known valid MCP protocol versions. +var ValidProtocolVersions = []string{ + LATEST_PROTOCOL_VERSION, + "2025-03-26", + "2024-11-05", +} + +// JSONRPC_VERSION is the version of JSON-RPC used by MCP. +const JSONRPC_VERSION = "2.0" + +// ProgressToken is used to associate progress notifications with the original request. +type ProgressToken any + +// Cursor is an opaque token used to represent a cursor for pagination. +type Cursor string + +// Meta is metadata attached to a request's parameters. This can include fields +// formally defined by the protocol or other arbitrary data. +type Meta struct { + // If specified, the caller is requesting out-of-band progress + // notifications for this request (as represented by + // notifications/progress). The value of this parameter is an + // opaque token that will be attached to any subsequent + // notifications. The receiver is not obligated to provide these + // notifications. + ProgressToken ProgressToken + + // AdditionalFields are any fields present in the Meta that are not + // otherwise defined in the protocol. + AdditionalFields map[string]any +} + +func (m *Meta) MarshalJSON() ([]byte, error) { + raw := make(map[string]any) + if m.ProgressToken != nil { + raw["progressToken"] = m.ProgressToken + } + maps.Copy(raw, m.AdditionalFields) + + return json.Marshal(raw) +} + +func (m *Meta) UnmarshalJSON(data []byte) error { + raw := make(map[string]any) + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + m.ProgressToken = raw["progressToken"] + delete(raw, "progressToken") + m.AdditionalFields = raw + return nil +} + +func NewMetaFromMap(m map[string]any) *Meta { + progressToken := m["progressToken"] + if progressToken != nil { + delete(m, "progressToken") + } + + return &Meta{ + ProgressToken: progressToken, + AdditionalFields: m, + } +} + +type Request struct { + Method string `json:"method"` + Params RequestParams `json:"params,omitempty"` +} + +type RequestParams struct { + Meta *Meta `json:"_meta,omitempty"` +} + +type Params map[string]any + +type Notification struct { + Method string `json:"method"` + Params NotificationParams `json:"params,omitempty"` +} + +type NotificationParams struct { + // This parameter name is reserved by MCP to allow clients and + // servers to attach additional metadata to their notifications. + Meta map[string]any `json:"_meta,omitempty"` + + // Additional fields can be added to this map + AdditionalFields map[string]any `json:"-"` +} + +// MarshalJSON implements custom JSON marshaling +func (p NotificationParams) MarshalJSON() ([]byte, error) { + // Create a map to hold all fields + m := make(map[string]any) + + // Add Meta if it exists + if p.Meta != nil { + m["_meta"] = p.Meta + } + + // Add all additional fields + for k, v := range p.AdditionalFields { + // Ensure we don't override the _meta field + if k != "_meta" { + m[k] = v + } + } + + return json.Marshal(m) +} + +// UnmarshalJSON implements custom JSON unmarshaling +func (p *NotificationParams) UnmarshalJSON(data []byte) error { + // Create a map to hold all fields + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return err + } + + // Initialize maps if they're nil + if p.Meta == nil { + p.Meta = make(map[string]any) + } + if p.AdditionalFields == nil { + p.AdditionalFields = make(map[string]any) + } + + // Process all fields + for k, v := range m { + if k == "_meta" { + // Handle Meta field + if meta, ok := v.(map[string]any); ok { + p.Meta = meta + } + } else { + // Handle additional fields + p.AdditionalFields[k] = v + } + } + + return nil +} + +type Result struct { + // This result property is reserved by the protocol to allow clients and + // servers to attach additional metadata to their responses. + Meta *Meta `json:"_meta,omitempty"` +} + +// RequestId is a uniquely identifying ID for a request in JSON-RPC. +// It can be any JSON-serializable value, typically a number or string. +type RequestId struct { + value any +} + +// NewRequestId creates a new RequestId with the given value +func NewRequestId(value any) RequestId { + return RequestId{value: value} +} + +// Value returns the underlying value of the RequestId +func (r RequestId) Value() any { + return r.value +} + +// String returns a string representation of the RequestId +func (r RequestId) String() string { + switch v := r.value.(type) { + case string: + return "string:" + v + case int64: + return "int64:" + strconv.FormatInt(v, 10) + case float64: + if v == float64(int64(v)) { + return "int64:" + strconv.FormatInt(int64(v), 10) + } + return "float64:" + strconv.FormatFloat(v, 'f', -1, 64) + case nil: + return "" + default: + return "unknown:" + fmt.Sprintf("%v", v) + } +} + +// IsNil returns true if the RequestId is nil +func (r RequestId) IsNil() bool { + return r.value == nil +} + +func (r RequestId) MarshalJSON() ([]byte, error) { + return json.Marshal(r.value) +} + +func (r *RequestId) UnmarshalJSON(data []byte) error { + + if string(data) == "null" { + r.value = nil + return nil + } + + // Try unmarshaling as string first + var s string + if err := json.Unmarshal(data, &s); err == nil { + r.value = s + return nil + } + + // JSON numbers are unmarshaled as float64 in Go + var f float64 + if err := json.Unmarshal(data, &f); err == nil { + if f == float64(int64(f)) { + r.value = int64(f) + } else { + r.value = f + } + return nil + } + + return fmt.Errorf("invalid request id: %s", string(data)) +} + +// JSONRPCRequest represents a request that expects a response. +type JSONRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Params any `json:"params,omitempty"` + Request +} + +// JSONRPCNotification represents a notification which does not expect a response. +type JSONRPCNotification struct { + JSONRPC string `json:"jsonrpc"` + Notification +} + +// JSONRPCResponse represents a successful (non-error) response to a request. +type JSONRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Result any `json:"result"` +} + +// JSONRPCError represents a non-successful (error) response to a request. +type JSONRPCError struct { + JSONRPC string `json:"jsonrpc"` + ID RequestId `json:"id"` + Error struct { + // The error type that occurred. + Code int `json:"code"` + // A short description of the error. The message SHOULD be limited + // to a concise single sentence. + Message string `json:"message"` + // Additional information about the error. The value of this member + // is defined by the sender (e.g. detailed error information, nested errors etc.). + Data any `json:"data,omitempty"` + } `json:"error"` +} + +// Standard JSON-RPC error codes +const ( + PARSE_ERROR = -32700 + INVALID_REQUEST = -32600 + METHOD_NOT_FOUND = -32601 + INVALID_PARAMS = -32602 + INTERNAL_ERROR = -32603 +) + +// MCP error codes +const ( + RESOURCE_NOT_FOUND = -32002 +) + +/* Empty result */ + +// EmptyResult represents a response that indicates success but carries no data. +type EmptyResult Result + +/* Cancellation */ + +// CancelledNotification can be sent by either side to indicate that it is +// cancelling a previously-issued request. +// +// The request SHOULD still be in-flight, but due to communication latency, it +// is always possible that this notification MAY arrive after the request has +// already finished. +// +// This notification indicates that the result will be unused, so any +// associated processing SHOULD cease. +// +// A client MUST NOT attempt to cancel its `initialize` request. +type CancelledNotification struct { + Notification + Params CancelledNotificationParams `json:"params"` +} + +type CancelledNotificationParams struct { + // The ID of the request to cancel. + // + // This MUST correspond to the ID of a request previously issued + // in the same direction. + RequestId RequestId `json:"requestId"` + + // An optional string describing the reason for the cancellation. This MAY + // be logged or presented to the user. + Reason string `json:"reason,omitempty"` +} + +/* Initialization */ + +// InitializeRequest is sent from the client to the server when it first +// connects, asking it to begin initialization. +type InitializeRequest struct { + Request + Params InitializeParams `json:"params"` + Header http.Header `json:"-"` +} + +type InitializeParams struct { + // The latest version of the Model Context Protocol that the client supports. + // The client MAY decide to support older versions as well. + ProtocolVersion string `json:"protocolVersion"` + Capabilities ClientCapabilities `json:"capabilities"` + ClientInfo Implementation `json:"clientInfo"` +} + +// InitializeResult is sent after receiving an initialize request from the +// client. +type InitializeResult struct { + Result + // The version of the Model Context Protocol that the server wants to use. + // This may not match the version that the client requested. If the client cannot + // support this version, it MUST disconnect. + ProtocolVersion string `json:"protocolVersion"` + Capabilities ServerCapabilities `json:"capabilities"` + ServerInfo Implementation `json:"serverInfo"` + // Instructions describing how to use the server and its features. + // + // This can be used by clients to improve the LLM's understanding of + // available tools, resources, etc. It can be thought of like a "hint" to the model. + // For example, this information MAY be added to the system prompt. + Instructions string `json:"instructions,omitempty"` +} + +// InitializedNotification is sent from the client to the server after +// initialization has finished. +type InitializedNotification struct { + Notification +} + +// ClientCapabilities represents capabilities a client may support. Known +// capabilities are defined here, in this schema, but this is not a closed set: any +// client can define its own, additional capabilities. +type ClientCapabilities struct { + // Experimental, non-standard capabilities that the client supports. + Experimental map[string]any `json:"experimental,omitempty"` + // Present if the client supports listing roots. + Roots *struct { + // Whether the client supports notifications for changes to the roots list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"roots,omitempty"` + // Present if the client supports sampling from an LLM. + Sampling *struct{} `json:"sampling,omitempty"` +} + +// ServerCapabilities represents capabilities that a server may support. Known +// capabilities are defined here, in this schema, but this is not a closed set: any +// server can define its own, additional capabilities. +type ServerCapabilities struct { + // Experimental, non-standard capabilities that the server supports. + Experimental map[string]any `json:"experimental,omitempty"` + // Present if the server supports sending log messages to the client. + Logging *struct{} `json:"logging,omitempty"` + // Present if the server offers any prompt templates. + Prompts *struct { + // Whether this server supports notifications for changes to the prompt list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"prompts,omitempty"` + // Present if the server offers any resources to read. + Resources *struct { + // Whether this server supports subscribing to resource updates. + Subscribe bool `json:"subscribe,omitempty"` + // Whether this server supports notifications for changes to the resource + // list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"resources,omitempty"` + // Present if the server supports sending sampling requests to clients. + Sampling *struct{} `json:"sampling,omitempty"` + // Present if the server offers any tools to call. + Tools *struct { + // Whether this server supports notifications for changes to the tool list. + ListChanged bool `json:"listChanged,omitempty"` + } `json:"tools,omitempty"` +} + +// Implementation describes the name and version of an MCP implementation. +type Implementation struct { + Name string `json:"name"` + Version string `json:"version"` +} + +/* Ping */ + +// PingRequest represents a ping, issued by either the server or the client, +// to check that the other party is still alive. The receiver must promptly respond, +// or else may be disconnected. +type PingRequest struct { + Request + Header http.Header `json:"-"` +} + +/* Progress notifications */ + +// ProgressNotification is an out-of-band notification used to inform the +// receiver of a progress update for a long-running request. +type ProgressNotification struct { + Notification + Params ProgressNotificationParams `json:"params"` +} + +type ProgressNotificationParams struct { + // The progress token which was given in the initial request, used to + // associate this notification with the request that is proceeding. + ProgressToken ProgressToken `json:"progressToken"` + // The progress thus far. This should increase every time progress is made, + // even if the total is unknown. + Progress float64 `json:"progress"` + // Total number of items to process (or total progress required), if known. + Total float64 `json:"total,omitempty"` + // Message related to progress. This should provide relevant human-readable + // progress information. + Message string `json:"message,omitempty"` +} + +/* Pagination */ + +type PaginatedRequest struct { + Request + Params PaginatedParams `json:"params,omitempty"` +} + +type PaginatedParams struct { + // An opaque token representing the current pagination position. + // If provided, the server should return results starting after this cursor. + Cursor Cursor `json:"cursor,omitempty"` +} + +type PaginatedResult struct { + Result + // An opaque token representing the pagination position after the last + // returned result. + // If present, there may be more results available. + NextCursor Cursor `json:"nextCursor,omitempty"` +} + +/* Resources */ + +// ListResourcesRequest is sent from the client to request a list of resources +// the server has. +type ListResourcesRequest struct { + PaginatedRequest + Header http.Header `json:"-"` +} + +// ListResourcesResult is the server's response to a resources/list request +// from the client. +type ListResourcesResult struct { + PaginatedResult + Resources []Resource `json:"resources"` +} + +// ListResourceTemplatesRequest is sent from the client to request a list of +// resource templates the server has. +type ListResourceTemplatesRequest struct { + PaginatedRequest + Header http.Header `json:"-"` +} + +// ListResourceTemplatesResult is the server's response to a +// resources/templates/list request from the client. +type ListResourceTemplatesResult struct { + PaginatedResult + ResourceTemplates []ResourceTemplate `json:"resourceTemplates"` +} + +// ReadResourceRequest is sent from the client to the server, to read a +// specific resource URI. +type ReadResourceRequest struct { + Request + Header http.Header `json:"-"` + Params ReadResourceParams `json:"params"` +} + +type ReadResourceParams struct { + // The URI of the resource to read. The URI can use any protocol; it is up + // to the server how to interpret it. + URI string `json:"uri"` + // Arguments to pass to the resource handler + Arguments map[string]any `json:"arguments,omitempty"` +} + +// ReadResourceResult is the server's response to a resources/read request +// from the client. +type ReadResourceResult struct { + Result + Contents []ResourceContents `json:"contents"` // Can be TextResourceContents or BlobResourceContents +} + +// ResourceListChangedNotification is an optional notification from the server +// to the client, informing it that the list of resources it can read from has +// changed. This may be issued by servers without any previous subscription from +// the client. +type ResourceListChangedNotification struct { + Notification +} + +// SubscribeRequest is sent from the client to request resources/updated +// notifications from the server whenever a particular resource changes. +type SubscribeRequest struct { + Request + Params SubscribeParams `json:"params"` + Header http.Header `json:"-"` +} + +type SubscribeParams struct { + // The URI of the resource to subscribe to. The URI can use any protocol; it + // is up to the server how to interpret it. + URI string `json:"uri"` +} + +// UnsubscribeRequest is sent from the client to request cancellation of +// resources/updated notifications from the server. This should follow a previous +// resources/subscribe request. +type UnsubscribeRequest struct { + Request + Params UnsubscribeParams `json:"params"` + Header http.Header `json:"-"` +} + +type UnsubscribeParams struct { + // The URI of the resource to unsubscribe from. + URI string `json:"uri"` +} + +// ResourceUpdatedNotification is a notification from the server to the client, +// informing it that a resource has changed and may need to be read again. This +// should only be sent if the client previously sent a resources/subscribe request. +type ResourceUpdatedNotification struct { + Notification + Params ResourceUpdatedNotificationParams `json:"params"` +} +type ResourceUpdatedNotificationParams struct { + // The URI of the resource that has been updated. This might be a sub- + // resource of the one that the client actually subscribed to. + URI string `json:"uri"` +} + +// Resource represents a known resource that the server is capable of reading. +type Resource struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // The URI of this resource. + URI string `json:"uri"` + // A human-readable name for this resource. + // + // This can be used by clients to populate UI elements. + Name string `json:"name"` + // A description of what this resource represents. + // + // This can be used by clients to improve the LLM's understanding of + // available resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` +} + +// GetName returns the name of the resource. +func (r Resource) GetName() string { + return r.Name +} + +// ResourceTemplate represents a template description for resources available +// on the server. +type ResourceTemplate struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // A URI template (according to RFC 6570) that can be used to construct + // resource URIs. + URITemplate *URITemplate `json:"uriTemplate"` + // A human-readable name for the type of resource this template refers to. + // + // This can be used by clients to populate UI elements. + Name string `json:"name"` + // A description of what this template is for. + // + // This can be used by clients to improve the LLM's understanding of + // available resources. It can be thought of like a "hint" to the model. + Description string `json:"description,omitempty"` + // The MIME type for all resources that match this template. This should only + // be included if all resources matching this template have the same type. + MIMEType string `json:"mimeType,omitempty"` +} + +// GetName returns the name of the resourceTemplate. +func (rt ResourceTemplate) GetName() string { + return rt.Name +} + +// ResourceContents represents the contents of a specific resource or sub- +// resource. +type ResourceContents interface { + isResourceContents() +} + +type TextResourceContents struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // The URI of this resource. + URI string `json:"uri"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // The text of the item. This must only be set if the item can actually be + // represented as text (not binary data). + Text string `json:"text"` +} + +func (TextResourceContents) isResourceContents() {} + +type BlobResourceContents struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // The URI of this resource. + URI string `json:"uri"` + // The MIME type of this resource, if known. + MIMEType string `json:"mimeType,omitempty"` + // A base64-encoded string representing the binary data of the item. + Blob string `json:"blob"` +} + +func (BlobResourceContents) isResourceContents() {} + +/* Logging */ + +// SetLevelRequest is a request from the client to the server, to enable or +// adjust logging. +type SetLevelRequest struct { + Request + Params SetLevelParams `json:"params"` + Header http.Header `json:"-"` +} + +type SetLevelParams struct { + // The level of logging that the client wants to receive from the server. + // The server should send all logs at this level and higher (i.e., more severe) to + // the client as notifications/logging/message. + Level LoggingLevel `json:"level"` +} + +// LoggingMessageNotification is a notification of a log message passed from +// server to client. If no logging/setLevel request has been sent from the client, +// the server MAY decide which messages to send automatically. +type LoggingMessageNotification struct { + Notification + Params LoggingMessageNotificationParams `json:"params"` +} + +type LoggingMessageNotificationParams struct { + // The severity of this log message. + Level LoggingLevel `json:"level"` + // An optional name of the logger issuing this message. + Logger string `json:"logger,omitempty"` + // The data to be logged, such as a string message or an object. Any JSON + // serializable type is allowed here. + Data any `json:"data"` +} + +// LoggingLevel represents the severity of a log message. +// +// These map to syslog message severities, as specified in RFC-5424: +// https://datatracker.ietf.org/doc/html/rfc5424#section-6.2.1 +type LoggingLevel string + +const ( + LoggingLevelDebug LoggingLevel = "debug" + LoggingLevelInfo LoggingLevel = "info" + LoggingLevelNotice LoggingLevel = "notice" + LoggingLevelWarning LoggingLevel = "warning" + LoggingLevelError LoggingLevel = "error" + LoggingLevelCritical LoggingLevel = "critical" + LoggingLevelAlert LoggingLevel = "alert" + LoggingLevelEmergency LoggingLevel = "emergency" +) + +var levelToInt = map[LoggingLevel]int{ + LoggingLevelDebug: 0, + LoggingLevelInfo: 1, + LoggingLevelNotice: 2, + LoggingLevelWarning: 3, + LoggingLevelError: 4, + LoggingLevelCritical: 5, + LoggingLevelAlert: 6, + LoggingLevelEmergency: 7, +} + +func (l LoggingLevel) ShouldSendTo(minLevel LoggingLevel) bool { + ia, oka := levelToInt[l] + ib, okb := levelToInt[minLevel] + if !oka || !okb { + return false + } + return ia >= ib +} + +/* Sampling */ + +const ( + // MethodSamplingCreateMessage allows servers to request LLM completions from clients + MethodSamplingCreateMessage MCPMethod = "sampling/createMessage" +) + +// CreateMessageRequest is a request from the server to sample an LLM via the +// client. The client has full discretion over which model to select. The client +// should also inform the user before beginning sampling, to allow them to inspect +// the request (human in the loop) and decide whether to approve it. +type CreateMessageRequest struct { + Request + CreateMessageParams `json:"params"` +} + +type CreateMessageParams struct { + Messages []SamplingMessage `json:"messages"` + ModelPreferences *ModelPreferences `json:"modelPreferences,omitempty"` + SystemPrompt string `json:"systemPrompt,omitempty"` + IncludeContext string `json:"includeContext,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"maxTokens"` + StopSequences []string `json:"stopSequences,omitempty"` + Metadata any `json:"metadata,omitempty"` +} + +// CreateMessageResult is the client's response to a sampling/create_message +// request from the server. The client should inform the user before returning the +// sampled message, to allow them to inspect the response (human in the loop) and +// decide whether to allow the server to see it. +type CreateMessageResult struct { + Result + SamplingMessage + // The name of the model that generated the message. + Model string `json:"model"` + // The reason why sampling stopped, if known. + StopReason string `json:"stopReason,omitempty"` +} + +// SamplingMessage describes a message issued to or received from an LLM API. +type SamplingMessage struct { + Role Role `json:"role"` + Content any `json:"content"` // Can be TextContent, ImageContent or AudioContent +} + +type Annotations struct { + // Describes who the intended customer of this object or data is. + // + // It can include multiple entries to indicate content useful for multiple + // audiences (e.g., `["user", "assistant"]`). + Audience []Role `json:"audience,omitempty"` + + // Describes how important this data is for operating the server. + // + // A value of 1 means "most important," and indicates that the data is + // effectively required, while 0 means "least important," and indicates that + // the data is entirely optional. + Priority float64 `json:"priority,omitempty"` +} + +// Annotated is the base for objects that include optional annotations for the +// client. The client can use annotations to inform how objects are used or +// displayed +type Annotated struct { + Annotations *Annotations `json:"annotations,omitempty"` +} + +type Content interface { + isContent() +} + +// TextContent represents text provided to or from an LLM. +// It must have Type set to "text". +type TextContent struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + Type string `json:"type"` // Must be "text" + // The text content of the message. + Text string `json:"text"` +} + +func (TextContent) isContent() {} + +// ImageContent represents an image provided to or from an LLM. +// It must have Type set to "image". +type ImageContent struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + Type string `json:"type"` // Must be "image" + // The base64-encoded image data. + Data string `json:"data"` + // The MIME type of the image. Different providers may support different image types. + MIMEType string `json:"mimeType"` +} + +func (ImageContent) isContent() {} + +// AudioContent represents the contents of audio, embedded into a prompt or tool call result. +// It must have Type set to "audio". +type AudioContent struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + Type string `json:"type"` // Must be "audio" + // The base64-encoded audio data. + Data string `json:"data"` + // The MIME type of the audio. Different providers may support different audio types. + MIMEType string `json:"mimeType"` +} + +func (AudioContent) isContent() {} + +// ResourceLink represents a link to a resource that the client can access. +type ResourceLink struct { + Annotated + Type string `json:"type"` // Must be "resource_link" + // The URI of the resource. + URI string `json:"uri"` + // The name of the resource. + Name string `json:"name"` + // The description of the resource. + Description string `json:"description"` + // The MIME type of the resource. + MIMEType string `json:"mimeType"` +} + +func (ResourceLink) isContent() {} + +// EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. +// +// It is up to the client how best to render embedded resources for the +// benefit of the LLM and/or the user. +type EmbeddedResource struct { + Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + Type string `json:"type"` + Resource ResourceContents `json:"resource"` +} + +func (EmbeddedResource) isContent() {} + +// ModelPreferences represents the server's preferences for model selection, +// requested of the client during sampling. +// +// Because LLMs can vary along multiple dimensions, choosing the "best" modelis +// rarely straightforward. Different models excel in different areas—some are +// faster but less capable, others are more capable but more expensive, and so +// on. This interface allows servers to express their priorities across multiple +// dimensions to help clients make an appropriate selection for their use case. +// +// These preferences are always advisory. The client MAY ignore them. It is also +// up to the client to decide how to interpret these preferences and how to +// balance them against other considerations. +type ModelPreferences struct { + // Optional hints to use for model selection. + // + // If multiple hints are specified, the client MUST evaluate them in order + // (such that the first match is taken). + // + // The client SHOULD prioritize these hints over the numeric priorities, but + // MAY still use the priorities to select from ambiguous matches. + Hints []ModelHint `json:"hints,omitempty"` + + // How much to prioritize cost when selecting a model. A value of 0 means cost + // is not important, while a value of 1 means cost is the most important + // factor. + CostPriority float64 `json:"costPriority,omitempty"` + + // How much to prioritize sampling speed (latency) when selecting a model. A + // value of 0 means speed is not important, while a value of 1 means speed is + // the most important factor. + SpeedPriority float64 `json:"speedPriority,omitempty"` + + // How much to prioritize intelligence and capabilities when selecting a + // model. A value of 0 means intelligence is not important, while a value of 1 + // means intelligence is the most important factor. + IntelligencePriority float64 `json:"intelligencePriority,omitempty"` +} + +// ModelHint represents hints to use for model selection. +// +// Keys not declared here are currently left unspecified by the spec and are up +// to the client to interpret. +type ModelHint struct { + // A hint for a model name. + // + // The client SHOULD treat this as a substring of a model name; for example: + // - `claude-3-5-sonnet` should match `claude-3-5-sonnet-20241022` + // - `sonnet` should match `claude-3-5-sonnet-20241022`, `claude-3-sonnet-20240229`, etc. + // - `claude` should match any Claude model + // + // The client MAY also map the string to a different provider's model name or + // a different model family, as long as it fills a similar niche; for example: + // - `gemini-1.5-flash` could match `claude-3-haiku-20240307` + Name string `json:"name,omitempty"` +} + +/* Autocomplete */ + +// CompleteRequest is a request from the client to the server, to ask for completion options. +type CompleteRequest struct { + Request + Params CompleteParams `json:"params"` + Header http.Header `json:"-"` +} + +type CompleteParams struct { + Ref any `json:"ref"` // Can be PromptReference or ResourceReference + Argument struct { + // The name of the argument + Name string `json:"name"` + // The value of the argument to use for completion matching. + Value string `json:"value"` + } `json:"argument"` +} + +// CompleteResult is the server's response to a completion/complete request +type CompleteResult struct { + Result + Completion struct { + // An array of completion values. Must not exceed 100 items. + Values []string `json:"values"` + // The total number of completion options available. This can exceed the + // number of values actually sent in the response. + Total int `json:"total,omitempty"` + // Indicates whether there are additional completion options beyond those + // provided in the current response, even if the exact total is unknown. + HasMore bool `json:"hasMore,omitempty"` + } `json:"completion"` +} + +// ResourceReference is a reference to a resource or resource template definition. +type ResourceReference struct { + Type string `json:"type"` + // The URI or URI template of the resource. + URI string `json:"uri"` +} + +// PromptReference identifies a prompt. +type PromptReference struct { + Type string `json:"type"` + // The name of the prompt or prompt template + Name string `json:"name"` +} + +/* Roots */ + +// ListRootsRequest is sent from the server to request a list of root URIs from the client. Roots allow +// servers to ask for specific directories or files to operate on. A common example +// for roots is providing a set of repositories or directories a server should operate +// on. +// +// This request is typically used when the server needs to understand the file system +// structure or access specific locations that the client has permission to read from. +type ListRootsRequest struct { + Request + Header http.Header `json:"-"` +} + +// ListRootsResult is the client's response to a roots/list request from the server. +// This result contains an array of Root objects, each representing a root directory +// or file that the server can operate on. +type ListRootsResult struct { + Result + Roots []Root `json:"roots"` +} + +// Root represents a root directory or file that the server can operate on. +type Root struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` + // The URI identifying the root. This *must* start with file:// for now. + // This restriction may be relaxed in future versions of the protocol to allow + // other URI schemes. + URI string `json:"uri"` + // An optional name for the root. This can be used to provide a human-readable + // identifier for the root, which may be useful for display purposes or for + // referencing the root in other parts of the application. + Name string `json:"name,omitempty"` +} + +// RootsListChangedNotification is a notification from the client to the +// server, informing it that the list of roots has changed. +// This notification should be sent whenever the client adds, removes, or modifies any root. +// The server should then request an updated list of roots using the ListRootsRequest. +type RootsListChangedNotification struct { + Notification +} + +// ClientRequest represents any request that can be sent from client to server. +type ClientRequest any + +// ClientNotification represents any notification that can be sent from client to server. +type ClientNotification any + +// ClientResult represents any result that can be sent from client to server. +type ClientResult any + +// ServerRequest represents any request that can be sent from server to client. +type ServerRequest any + +// ServerNotification represents any notification that can be sent from server to client. +type ServerNotification any + +// ServerResult represents any result that can be sent from server to client. +type ServerResult any + +type Named interface { + GetName() string +} + +// MarshalJSON implements custom JSON marshaling for Content interface +func MarshalContent(content Content) ([]byte, error) { + return json.Marshal(content) +} + +// UnmarshalContent implements custom JSON unmarshaling for Content interface +func UnmarshalContent(data []byte) (Content, error) { + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + contentType, ok := raw["type"].(string) + if !ok { + return nil, fmt.Errorf("missing or invalid type field") + } + + switch contentType { + case ContentTypeText: + var content TextContent + err := json.Unmarshal(data, &content) + return content, err + case ContentTypeImage: + var content ImageContent + err := json.Unmarshal(data, &content) + return content, err + case ContentTypeAudio: + var content AudioContent + err := json.Unmarshal(data, &content) + return content, err + case ContentTypeLink: + var content ResourceLink + err := json.Unmarshal(data, &content) + return content, err + case ContentTypeResource: + var content EmbeddedResource + err := json.Unmarshal(data, &content) + return content, err + default: + return nil, fmt.Errorf("unknown content type: %s", contentType) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/mcp/utils.go b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go new file mode 100644 index 0000000000..b8deeae9c5 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/mcp/utils.go @@ -0,0 +1,863 @@ +package mcp + +import ( + "encoding/json" + "fmt" + + "github.com/spf13/cast" +) + +// ClientRequest types +var _ ClientRequest = &PingRequest{} +var _ ClientRequest = &InitializeRequest{} +var _ ClientRequest = &CompleteRequest{} +var _ ClientRequest = &SetLevelRequest{} +var _ ClientRequest = &GetPromptRequest{} +var _ ClientRequest = &ListPromptsRequest{} +var _ ClientRequest = &ListResourcesRequest{} +var _ ClientRequest = &ReadResourceRequest{} +var _ ClientRequest = &SubscribeRequest{} +var _ ClientRequest = &UnsubscribeRequest{} +var _ ClientRequest = &CallToolRequest{} +var _ ClientRequest = &ListToolsRequest{} + +// ClientNotification types +var _ ClientNotification = &CancelledNotification{} +var _ ClientNotification = &ProgressNotification{} +var _ ClientNotification = &InitializedNotification{} +var _ ClientNotification = &RootsListChangedNotification{} + +// ClientResult types +var _ ClientResult = &EmptyResult{} +var _ ClientResult = &CreateMessageResult{} +var _ ClientResult = &ListRootsResult{} + +// ServerRequest types +var _ ServerRequest = &PingRequest{} +var _ ServerRequest = &CreateMessageRequest{} +var _ ServerRequest = &ListRootsRequest{} + +// ServerNotification types +var _ ServerNotification = &CancelledNotification{} +var _ ServerNotification = &ProgressNotification{} +var _ ServerNotification = &LoggingMessageNotification{} +var _ ServerNotification = &ResourceUpdatedNotification{} +var _ ServerNotification = &ResourceListChangedNotification{} +var _ ServerNotification = &ToolListChangedNotification{} +var _ ServerNotification = &PromptListChangedNotification{} + +// ServerResult types +var _ ServerResult = &EmptyResult{} +var _ ServerResult = &InitializeResult{} +var _ ServerResult = &CompleteResult{} +var _ ServerResult = &GetPromptResult{} +var _ ServerResult = &ListPromptsResult{} +var _ ServerResult = &ListResourcesResult{} +var _ ServerResult = &ReadResourceResult{} +var _ ServerResult = &CallToolResult{} +var _ ServerResult = &ListToolsResult{} + +// Helper functions for type assertions + +// asType attempts to cast the given interface to the given type +func asType[T any](content any) (*T, bool) { + tc, ok := content.(T) + if !ok { + return nil, false + } + return &tc, true +} + +// AsTextContent attempts to cast the given interface to TextContent +func AsTextContent(content any) (*TextContent, bool) { + return asType[TextContent](content) +} + +// AsImageContent attempts to cast the given interface to ImageContent +func AsImageContent(content any) (*ImageContent, bool) { + return asType[ImageContent](content) +} + +// AsAudioContent attempts to cast the given interface to AudioContent +func AsAudioContent(content any) (*AudioContent, bool) { + return asType[AudioContent](content) +} + +// AsEmbeddedResource attempts to cast the given interface to EmbeddedResource +func AsEmbeddedResource(content any) (*EmbeddedResource, bool) { + return asType[EmbeddedResource](content) +} + +// AsTextResourceContents attempts to cast the given interface to TextResourceContents +func AsTextResourceContents(content any) (*TextResourceContents, bool) { + return asType[TextResourceContents](content) +} + +// AsBlobResourceContents attempts to cast the given interface to BlobResourceContents +func AsBlobResourceContents(content any) (*BlobResourceContents, bool) { + return asType[BlobResourceContents](content) +} + +// Helper function for JSON-RPC + +// NewJSONRPCResponse creates a new JSONRPCResponse with the given id and result +func NewJSONRPCResponse(id RequestId, result Result) JSONRPCResponse { + return JSONRPCResponse{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Result: result, + } +} + +// NewJSONRPCError creates a new JSONRPCResponse with the given id, code, and message +func NewJSONRPCError( + id RequestId, + code int, + message string, + data any, +) JSONRPCError { + return JSONRPCError{ + JSONRPC: JSONRPC_VERSION, + ID: id, + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` + }{ + Code: code, + Message: message, + Data: data, + }, + } +} + +// NewProgressNotification +// Helper function for creating a progress notification +func NewProgressNotification( + token ProgressToken, + progress float64, + total *float64, + message *string, +) ProgressNotification { + notification := ProgressNotification{ + Notification: Notification{ + Method: "notifications/progress", + }, + Params: struct { + ProgressToken ProgressToken `json:"progressToken"` + Progress float64 `json:"progress"` + Total float64 `json:"total,omitempty"` + Message string `json:"message,omitempty"` + }{ + ProgressToken: token, + Progress: progress, + }, + } + if total != nil { + notification.Params.Total = *total + } + if message != nil { + notification.Params.Message = *message + } + return notification +} + +// NewLoggingMessageNotification +// Helper function for creating a logging message notification +func NewLoggingMessageNotification( + level LoggingLevel, + logger string, + data any, +) LoggingMessageNotification { + return LoggingMessageNotification{ + Notification: Notification{ + Method: "notifications/message", + }, + Params: struct { + Level LoggingLevel `json:"level"` + Logger string `json:"logger,omitempty"` + Data any `json:"data"` + }{ + Level: level, + Logger: logger, + Data: data, + }, + } +} + +// NewPromptMessage +// Helper function to create a new PromptMessage +func NewPromptMessage(role Role, content Content) PromptMessage { + return PromptMessage{ + Role: role, + Content: content, + } +} + +// NewTextContent +// Helper function to create a new TextContent +func NewTextContent(text string) TextContent { + return TextContent{ + Type: ContentTypeText, + Text: text, + } +} + +// NewImageContent +// Helper function to create a new ImageContent +func NewImageContent(data, mimeType string) ImageContent { + return ImageContent{ + Type: ContentTypeImage, + Data: data, + MIMEType: mimeType, + } +} + +// Helper function to create a new AudioContent +func NewAudioContent(data, mimeType string) AudioContent { + return AudioContent{ + Type: ContentTypeAudio, + Data: data, + MIMEType: mimeType, + } +} + +// Helper function to create a new ResourceLink +func NewResourceLink(uri, name, description, mimeType string) ResourceLink { + return ResourceLink{ + Type: ContentTypeLink, + URI: uri, + Name: name, + Description: description, + MIMEType: mimeType, + } +} + +// Helper function to create a new EmbeddedResource +func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { + return EmbeddedResource{ + Type: ContentTypeResource, + Resource: resource, + } +} + +// NewToolResultText creates a new CallToolResult with a text content +func NewToolResultText(text string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + }, + } +} + +// NewToolResultStructured creates a new CallToolResult with structured content. +// It includes both the structured content and a text representation for backward compatibility. +func NewToolResultStructured(structured any, fallbackText string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: fallbackText, + }, + }, + StructuredContent: structured, + } +} + +// NewToolResultStructuredOnly creates a new CallToolResult with structured +// content and creates a JSON string fallback for backwards compatibility. +// This is useful when you want to provide structured data without any specific text fallback. +func NewToolResultStructuredOnly(structured any) *CallToolResult { + var fallbackText string + // Convert to JSON string for backward compatibility + jsonBytes, err := json.Marshal(structured) + if err != nil { + fallbackText = fmt.Sprintf("Error serializing structured content: %v", err) + } else { + fallbackText = string(jsonBytes) + } + + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: fallbackText, + }, + }, + StructuredContent: structured, + } +} + +// NewToolResultImage creates a new CallToolResult with both text and image content +func NewToolResultImage(text, imageData, mimeType string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + ImageContent{ + Type: ContentTypeImage, + Data: imageData, + MIMEType: mimeType, + }, + }, + } +} + +// NewToolResultAudio creates a new CallToolResult with both text and audio content +func NewToolResultAudio(text, imageData, mimeType string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + AudioContent{ + Type: ContentTypeAudio, + Data: imageData, + MIMEType: mimeType, + }, + }, + } +} + +// NewToolResultResource creates a new CallToolResult with an embedded resource +func NewToolResultResource( + text string, + resource ResourceContents, +) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + EmbeddedResource{ + Type: ContentTypeResource, + Resource: resource, + }, + }, + } +} + +// NewToolResultError creates a new CallToolResult with an error message. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultError(text string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + }, + IsError: true, + } +} + +// NewToolResultErrorFromErr creates a new CallToolResult with an error message. +// If an error is provided, its details will be appended to the text message. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultErrorFromErr(text string, err error) *CallToolResult { + if err != nil { + text = fmt.Sprintf("%s: %v", text, err) + } + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: text, + }, + }, + IsError: true, + } +} + +// NewToolResultErrorf creates a new CallToolResult with an error message. +// The error message is formatted using the fmt package. +// Any errors that originate from the tool SHOULD be reported inside the result object. +func NewToolResultErrorf(format string, a ...any) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: ContentTypeText, + Text: fmt.Sprintf(format, a...), + }, + }, + IsError: true, + } +} + +// NewListResourcesResult creates a new ListResourcesResult +func NewListResourcesResult( + resources []Resource, + nextCursor Cursor, +) *ListResourcesResult { + return &ListResourcesResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Resources: resources, + } +} + +// NewListResourceTemplatesResult creates a new ListResourceTemplatesResult +func NewListResourceTemplatesResult( + templates []ResourceTemplate, + nextCursor Cursor, +) *ListResourceTemplatesResult { + return &ListResourceTemplatesResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + ResourceTemplates: templates, + } +} + +// NewReadResourceResult creates a new ReadResourceResult with text content +func NewReadResourceResult(text string) *ReadResourceResult { + return &ReadResourceResult{ + Contents: []ResourceContents{ + TextResourceContents{ + Text: text, + }, + }, + } +} + +// NewListPromptsResult creates a new ListPromptsResult +func NewListPromptsResult( + prompts []Prompt, + nextCursor Cursor, +) *ListPromptsResult { + return &ListPromptsResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Prompts: prompts, + } +} + +// NewGetPromptResult creates a new GetPromptResult +func NewGetPromptResult( + description string, + messages []PromptMessage, +) *GetPromptResult { + return &GetPromptResult{ + Description: description, + Messages: messages, + } +} + +// NewListToolsResult creates a new ListToolsResult +func NewListToolsResult(tools []Tool, nextCursor Cursor) *ListToolsResult { + return &ListToolsResult{ + PaginatedResult: PaginatedResult{ + NextCursor: nextCursor, + }, + Tools: tools, + } +} + +// NewInitializeResult creates a new InitializeResult +func NewInitializeResult( + protocolVersion string, + capabilities ServerCapabilities, + serverInfo Implementation, + instructions string, +) *InitializeResult { + return &InitializeResult{ + ProtocolVersion: protocolVersion, + Capabilities: capabilities, + ServerInfo: serverInfo, + Instructions: instructions, + } +} + +// FormatNumberResult +// Helper for formatting numbers in tool results +func FormatNumberResult(value float64) *CallToolResult { + return NewToolResultText(fmt.Sprintf("%.2f", value)) +} + +func ExtractString(data map[string]any, key string) string { + if value, ok := data[key]; ok { + if str, ok := value.(string); ok { + return str + } + } + return "" +} + +func ExtractMap(data map[string]any, key string) map[string]any { + if value, ok := data[key]; ok { + if m, ok := value.(map[string]any); ok { + return m + } + } + return nil +} + +func ParseContent(contentMap map[string]any) (Content, error) { + contentType := ExtractString(contentMap, "type") + + switch contentType { + case ContentTypeText: + text := ExtractString(contentMap, "text") + return NewTextContent(text), nil + + case ContentTypeImage: + data := ExtractString(contentMap, "data") + mimeType := ExtractString(contentMap, "mimeType") + if data == "" || mimeType == "" { + return nil, fmt.Errorf("image data or mimeType is missing") + } + return NewImageContent(data, mimeType), nil + + case ContentTypeAudio: + data := ExtractString(contentMap, "data") + mimeType := ExtractString(contentMap, "mimeType") + if data == "" || mimeType == "" { + return nil, fmt.Errorf("audio data or mimeType is missing") + } + return NewAudioContent(data, mimeType), nil + + case ContentTypeLink: + uri := ExtractString(contentMap, "uri") + name := ExtractString(contentMap, "name") + description := ExtractString(contentMap, "description") + mimeType := ExtractString(contentMap, "mimeType") + if uri == "" || name == "" { + return nil, fmt.Errorf("resource_link uri or name is missing") + } + return NewResourceLink(uri, name, description, mimeType), nil + + case ContentTypeResource: + resourceMap := ExtractMap(contentMap, "resource") + if resourceMap == nil { + return nil, fmt.Errorf("resource is missing") + } + + resourceContents, err := ParseResourceContents(resourceMap) + if err != nil { + return nil, err + } + + return NewEmbeddedResource(resourceContents), nil + } + + return nil, fmt.Errorf("unsupported content type: %s", contentType) +} + +func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) { + if rawMessage == nil { + return nil, fmt.Errorf("response is nil") + } + + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + result := GetPromptResult{} + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = NewMetaFromMap(metaMap) + } + } + + description, ok := jsonContent["description"] + if ok { + if descriptionStr, ok := description.(string); ok { + result.Description = descriptionStr + } + } + + messages, ok := jsonContent["messages"] + if ok { + messagesArr, ok := messages.([]any) + if !ok { + return nil, fmt.Errorf("messages is not an array") + } + + for _, message := range messagesArr { + messageMap, ok := message.(map[string]any) + if !ok { + return nil, fmt.Errorf("message is not an object") + } + + // Extract role + roleStr := ExtractString(messageMap, "role") + if roleStr == "" || (roleStr != string(RoleAssistant) && roleStr != string(RoleUser)) { + return nil, fmt.Errorf("unsupported role: %s", roleStr) + } + + // Extract content + contentMap, ok := messageMap["content"].(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + // Append processed message + result.Messages = append(result.Messages, NewPromptMessage(Role(roleStr), content)) + + } + } + + return &result, nil +} + +func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) { + if rawMessage == nil { + return nil, fmt.Errorf("response is nil") + } + + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var result CallToolResult + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = NewMetaFromMap(metaMap) + } + } + + isError, ok := jsonContent["isError"] + if ok { + if isErrorBool, ok := isError.(bool); ok { + result.IsError = isErrorBool + } + } + + contents, ok := jsonContent["content"] + if !ok { + return nil, fmt.Errorf("content is missing") + } + + contentArr, ok := contents.([]any) + if !ok { + return nil, fmt.Errorf("content is not an array") + } + + for _, content := range contentArr { + // Extract content + contentMap, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseContent(contentMap) + if err != nil { + return nil, err + } + + result.Content = append(result.Content, content) + } + + // Handle structured content + structuredContent, ok := jsonContent["structuredContent"] + if ok { + result.StructuredContent = structuredContent + } + + return &result, nil +} + +func ParseResourceContents(contentMap map[string]any) (ResourceContents, error) { + uri := ExtractString(contentMap, "uri") + if uri == "" { + return nil, fmt.Errorf("resource uri is missing") + } + + mimeType := ExtractString(contentMap, "mimeType") + + if text := ExtractString(contentMap, "text"); text != "" { + return TextResourceContents{ + URI: uri, + MIMEType: mimeType, + Text: text, + }, nil + } + + if blob := ExtractString(contentMap, "blob"); blob != "" { + return BlobResourceContents{ + URI: uri, + MIMEType: mimeType, + Blob: blob, + }, nil + } + + return nil, fmt.Errorf("unsupported resource type") +} + +func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, error) { + if rawMessage == nil { + return nil, fmt.Errorf("response is nil") + } + + var jsonContent map[string]any + if err := json.Unmarshal(*rawMessage, &jsonContent); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + var result ReadResourceResult + + meta, ok := jsonContent["_meta"] + if ok { + if metaMap, ok := meta.(map[string]any); ok { + result.Meta = NewMetaFromMap(metaMap) + } + } + + contents, ok := jsonContent["contents"] + if !ok { + return nil, fmt.Errorf("contents is missing") + } + + contentArr, ok := contents.([]any) + if !ok { + return nil, fmt.Errorf("contents is not an array") + } + + for _, content := range contentArr { + // Extract content + contentMap, ok := content.(map[string]any) + if !ok { + return nil, fmt.Errorf("content is not an object") + } + + // Process content + content, err := ParseResourceContents(contentMap) + if err != nil { + return nil, err + } + + result.Contents = append(result.Contents, content) + } + + return &result, nil +} + +func ParseArgument(request CallToolRequest, key string, defaultVal any) any { + args := request.GetArguments() + if _, ok := args[key]; !ok { + return defaultVal + } else { + return args[key] + } +} + +// ParseBoolean extracts and converts a boolean parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. +// The function uses cast.ToBool for conversion which handles various string representations +// such as "true", "yes", "1", etc. +func ParseBoolean(request CallToolRequest, key string, defaultValue bool) bool { + v := ParseArgument(request, key, defaultValue) + return cast.ToBool(v) +} + +// ParseInt64 extracts and converts an int64 parameter from a CallToolRequest. +// If the key is not found in the Arguments map, the defaultValue is returned. +func ParseInt64(request CallToolRequest, key string, defaultValue int64) int64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt64(v) +} + +// ParseInt32 extracts and converts an int32 parameter from a CallToolRequest. +func ParseInt32(request CallToolRequest, key string, defaultValue int32) int32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt32(v) +} + +// ParseInt16 extracts and converts an int16 parameter from a CallToolRequest. +func ParseInt16(request CallToolRequest, key string, defaultValue int16) int16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt16(v) +} + +// ParseInt8 extracts and converts an int8 parameter from a CallToolRequest. +func ParseInt8(request CallToolRequest, key string, defaultValue int8) int8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt8(v) +} + +// ParseInt extracts and converts an int parameter from a CallToolRequest. +func ParseInt(request CallToolRequest, key string, defaultValue int) int { + v := ParseArgument(request, key, defaultValue) + return cast.ToInt(v) +} + +// ParseUInt extracts and converts an uint parameter from a CallToolRequest. +func ParseUInt(request CallToolRequest, key string, defaultValue uint) uint { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint(v) +} + +// ParseUInt64 extracts and converts an uint64 parameter from a CallToolRequest. +func ParseUInt64(request CallToolRequest, key string, defaultValue uint64) uint64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint64(v) +} + +// ParseUInt32 extracts and converts an uint32 parameter from a CallToolRequest. +func ParseUInt32(request CallToolRequest, key string, defaultValue uint32) uint32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint32(v) +} + +// ParseUInt16 extracts and converts an uint16 parameter from a CallToolRequest. +func ParseUInt16(request CallToolRequest, key string, defaultValue uint16) uint16 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint16(v) +} + +// ParseUInt8 extracts and converts an uint8 parameter from a CallToolRequest. +func ParseUInt8(request CallToolRequest, key string, defaultValue uint8) uint8 { + v := ParseArgument(request, key, defaultValue) + return cast.ToUint8(v) +} + +// ParseFloat32 extracts and converts a float32 parameter from a CallToolRequest. +func ParseFloat32(request CallToolRequest, key string, defaultValue float32) float32 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat32(v) +} + +// ParseFloat64 extracts and converts a float64 parameter from a CallToolRequest. +func ParseFloat64(request CallToolRequest, key string, defaultValue float64) float64 { + v := ParseArgument(request, key, defaultValue) + return cast.ToFloat64(v) +} + +// ParseString extracts and converts a string parameter from a CallToolRequest. +func ParseString(request CallToolRequest, key string, defaultValue string) string { + v := ParseArgument(request, key, defaultValue) + return cast.ToString(v) +} + +// ParseStringMap extracts and converts a string map parameter from a CallToolRequest. +func ParseStringMap(request CallToolRequest, key string, defaultValue map[string]any) map[string]any { + v := ParseArgument(request, key, defaultValue) + return cast.ToStringMap(v) +} + +// ToBoolPtr returns a pointer to the given boolean value +func ToBoolPtr(b bool) *bool { + return &b +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/constants.go b/vendor/github.com/mark3labs/mcp-go/server/constants.go new file mode 100644 index 0000000000..e071b2ef45 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/constants.go @@ -0,0 +1,7 @@ +package server + +// Common HTTP header constants used across server transports +const ( + HeaderKeySessionID = "Mcp-Session-Id" + HeaderKeyProtocolVersion = "Mcp-Protocol-Version" +) diff --git a/vendor/github.com/mark3labs/mcp-go/server/ctx.go b/vendor/github.com/mark3labs/mcp-go/server/ctx.go new file mode 100644 index 0000000000..43f01bb68b --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/ctx.go @@ -0,0 +1,8 @@ +package server + +type contextKey int + +const ( + // This const is used as key for context value lookup + requestHeader contextKey = iota +) diff --git a/vendor/github.com/mark3labs/mcp-go/server/errors.go b/vendor/github.com/mark3labs/mcp-go/server/errors.go new file mode 100644 index 0000000000..3864f36f70 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/errors.go @@ -0,0 +1,34 @@ +package server + +import ( + "errors" + "fmt" +) + +var ( + // Common server errors + ErrUnsupported = errors.New("not supported") + ErrResourceNotFound = errors.New("resource not found") + ErrPromptNotFound = errors.New("prompt not found") + ErrToolNotFound = errors.New("tool not found") + + // Session-related errors + ErrSessionNotFound = errors.New("session not found") + ErrSessionExists = errors.New("session already exists") + ErrSessionNotInitialized = errors.New("session not properly initialized") + ErrSessionDoesNotSupportTools = errors.New("session does not support per-session tools") + ErrSessionDoesNotSupportLogging = errors.New("session does not support setting logging level") + + // Notification-related errors + ErrNotificationNotInitialized = errors.New("notification channel not initialized") + ErrNotificationChannelBlocked = errors.New("notification channel queue is full - client may not be processing notifications fast enough") +) + +// ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration +type ErrDynamicPathConfig struct { + Method string +} + +func (e *ErrDynamicPathConfig) Error() string { + return fmt.Sprintf("%s cannot be used with WithDynamicBasePath. Use dynamic path logic in your router.", e.Method) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/hooks.go b/vendor/github.com/mark3labs/mcp-go/server/hooks.go new file mode 100644 index 0000000000..4baa1c4e05 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/hooks.go @@ -0,0 +1,532 @@ +// Code generated by `go generate`. DO NOT EDIT. +// source: server/internal/gen/hooks.go.tmpl +package server + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// OnRegisterSessionHookFunc is a hook that will be called when a new session is registered. +type OnRegisterSessionHookFunc func(ctx context.Context, session ClientSession) + +// OnUnregisterSessionHookFunc is a hook that will be called when a session is being unregistered. +type OnUnregisterSessionHookFunc func(ctx context.Context, session ClientSession) + +// BeforeAnyHookFunc is a function that is called after the request is +// parsed but before the method is called. +type BeforeAnyHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any) + +// OnSuccessHookFunc is a hook that will be called after the request +// successfully generates a result, but before the result is sent to the client. +type OnSuccessHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) + +// OnErrorHookFunc is a hook that will be called when an error occurs, +// either during the request parsing or the method execution. +// +// Example usage: +// ``` +// +// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { +// // Check for specific error types using errors.Is +// if errors.Is(err, ErrUnsupported) { +// // Handle capability not supported errors +// log.Printf("Capability not supported: %v", err) +// } +// +// // Use errors.As to get specific error types +// var parseErr = &UnparsableMessageError{} +// if errors.As(err, &parseErr) { +// // Access specific methods/fields of the error type +// log.Printf("Failed to parse message for method %s: %v", +// parseErr.GetMethod(), parseErr.Unwrap()) +// // Access the raw message that failed to parse +// rawMsg := parseErr.GetMessage() +// } +// +// // Check for specific resource/prompt/tool errors +// switch { +// case errors.Is(err, ErrResourceNotFound): +// log.Printf("Resource not found: %v", err) +// case errors.Is(err, ErrPromptNotFound): +// log.Printf("Prompt not found: %v", err) +// case errors.Is(err, ErrToolNotFound): +// log.Printf("Tool not found: %v", err) +// } +// }) +type OnErrorHookFunc func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) + +// OnRequestInitializationFunc is a function that called before handle diff request method +// Should any errors arise during func execution, the service will promptly return the corresponding error message. +type OnRequestInitializationFunc func(ctx context.Context, id any, message any) error + +type OnBeforeInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest) +type OnAfterInitializeFunc func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) + +type OnBeforePingFunc func(ctx context.Context, id any, message *mcp.PingRequest) +type OnAfterPingFunc func(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) + +type OnBeforeSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest) +type OnAfterSetLevelFunc func(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) + +type OnBeforeListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest) +type OnAfterListResourcesFunc func(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) + +type OnBeforeListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) +type OnAfterListResourceTemplatesFunc func(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) + +type OnBeforeReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest) +type OnAfterReadResourceFunc func(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) + +type OnBeforeListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest) +type OnAfterListPromptsFunc func(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) + +type OnBeforeGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest) +type OnAfterGetPromptFunc func(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) + +type OnBeforeListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest) +type OnAfterListToolsFunc func(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) + +type OnBeforeCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest) +type OnAfterCallToolFunc func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) + +type Hooks struct { + OnRegisterSession []OnRegisterSessionHookFunc + OnUnregisterSession []OnUnregisterSessionHookFunc + OnBeforeAny []BeforeAnyHookFunc + OnSuccess []OnSuccessHookFunc + OnError []OnErrorHookFunc + OnRequestInitialization []OnRequestInitializationFunc + OnBeforeInitialize []OnBeforeInitializeFunc + OnAfterInitialize []OnAfterInitializeFunc + OnBeforePing []OnBeforePingFunc + OnAfterPing []OnAfterPingFunc + OnBeforeSetLevel []OnBeforeSetLevelFunc + OnAfterSetLevel []OnAfterSetLevelFunc + OnBeforeListResources []OnBeforeListResourcesFunc + OnAfterListResources []OnAfterListResourcesFunc + OnBeforeListResourceTemplates []OnBeforeListResourceTemplatesFunc + OnAfterListResourceTemplates []OnAfterListResourceTemplatesFunc + OnBeforeReadResource []OnBeforeReadResourceFunc + OnAfterReadResource []OnAfterReadResourceFunc + OnBeforeListPrompts []OnBeforeListPromptsFunc + OnAfterListPrompts []OnAfterListPromptsFunc + OnBeforeGetPrompt []OnBeforeGetPromptFunc + OnAfterGetPrompt []OnAfterGetPromptFunc + OnBeforeListTools []OnBeforeListToolsFunc + OnAfterListTools []OnAfterListToolsFunc + OnBeforeCallTool []OnBeforeCallToolFunc + OnAfterCallTool []OnAfterCallToolFunc +} + +func (c *Hooks) AddBeforeAny(hook BeforeAnyHookFunc) { + c.OnBeforeAny = append(c.OnBeforeAny, hook) +} + +func (c *Hooks) AddOnSuccess(hook OnSuccessHookFunc) { + c.OnSuccess = append(c.OnSuccess, hook) +} + +// AddOnError registers a hook function that will be called when an error occurs. +// The error parameter contains the actual error object, which can be interrogated +// using Go's error handling patterns like errors.Is and errors.As. +// +// Example: +// ``` +// // Create a channel to receive errors for testing +// errChan := make(chan error, 1) +// +// // Register hook to capture and inspect errors +// hooks := &Hooks{} +// +// hooks.AddOnError(func(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { +// // For capability-related errors +// if errors.Is(err, ErrUnsupported) { +// // Handle capability not supported +// errChan <- err +// return +// } +// +// // For parsing errors +// var parseErr = &UnparsableMessageError{} +// if errors.As(err, &parseErr) { +// // Handle unparsable message errors +// fmt.Printf("Failed to parse %s request: %v\n", +// parseErr.GetMethod(), parseErr.Unwrap()) +// errChan <- parseErr +// return +// } +// +// // For resource/prompt/tool not found errors +// if errors.Is(err, ErrResourceNotFound) || +// errors.Is(err, ErrPromptNotFound) || +// errors.Is(err, ErrToolNotFound) { +// // Handle not found errors +// errChan <- err +// return +// } +// +// // For other errors +// errChan <- err +// }) +// +// server := NewMCPServer("test-server", "1.0.0", WithHooks(hooks)) +// ``` +func (c *Hooks) AddOnError(hook OnErrorHookFunc) { + c.OnError = append(c.OnError, hook) +} + +func (c *Hooks) beforeAny(ctx context.Context, id any, method mcp.MCPMethod, message any) { + if c == nil { + return + } + for _, hook := range c.OnBeforeAny { + hook(ctx, id, method, message) + } +} + +func (c *Hooks) onSuccess(ctx context.Context, id any, method mcp.MCPMethod, message any, result any) { + if c == nil { + return + } + for _, hook := range c.OnSuccess { + hook(ctx, id, method, message, result) + } +} + +// onError calls all registered error hooks with the error object. +// The err parameter contains the actual error that occurred, which implements +// the standard error interface and may be a wrapped error or custom error type. +// +// This allows consumer code to use Go's error handling patterns: +// - errors.Is(err, ErrUnsupported) to check for specific sentinel errors +// - errors.As(err, &customErr) to extract custom error types +// +// Common error types include: +// - ErrUnsupported: When a capability is not enabled +// - UnparsableMessageError: When request parsing fails +// - ErrResourceNotFound: When a resource is not found +// - ErrPromptNotFound: When a prompt is not found +// - ErrToolNotFound: When a tool is not found +func (c *Hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) { + if c == nil { + return + } + for _, hook := range c.OnError { + hook(ctx, id, method, message, err) + } +} + +func (c *Hooks) AddOnRegisterSession(hook OnRegisterSessionHookFunc) { + c.OnRegisterSession = append(c.OnRegisterSession, hook) +} + +func (c *Hooks) RegisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnRegisterSession { + hook(ctx, session) + } +} + +func (c *Hooks) AddOnUnregisterSession(hook OnUnregisterSessionHookFunc) { + c.OnUnregisterSession = append(c.OnUnregisterSession, hook) +} + +func (c *Hooks) UnregisterSession(ctx context.Context, session ClientSession) { + if c == nil { + return + } + for _, hook := range c.OnUnregisterSession { + hook(ctx, session) + } +} + +func (c *Hooks) AddOnRequestInitialization(hook OnRequestInitializationFunc) { + c.OnRequestInitialization = append(c.OnRequestInitialization, hook) +} + +func (c *Hooks) onRequestInitialization(ctx context.Context, id any, message any) error { + if c == nil { + return nil + } + for _, hook := range c.OnRequestInitialization { + err := hook(ctx, id, message) + if err != nil { + return err + } + } + return nil +} +func (c *Hooks) AddBeforeInitialize(hook OnBeforeInitializeFunc) { + c.OnBeforeInitialize = append(c.OnBeforeInitialize, hook) +} + +func (c *Hooks) AddAfterInitialize(hook OnAfterInitializeFunc) { + c.OnAfterInitialize = append(c.OnAfterInitialize, hook) +} + +func (c *Hooks) beforeInitialize(ctx context.Context, id any, message *mcp.InitializeRequest) { + c.beforeAny(ctx, id, mcp.MethodInitialize, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeInitialize { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterInitialize(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { + c.onSuccess(ctx, id, mcp.MethodInitialize, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterInitialize { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforePing(hook OnBeforePingFunc) { + c.OnBeforePing = append(c.OnBeforePing, hook) +} + +func (c *Hooks) AddAfterPing(hook OnAfterPingFunc) { + c.OnAfterPing = append(c.OnAfterPing, hook) +} + +func (c *Hooks) beforePing(ctx context.Context, id any, message *mcp.PingRequest) { + c.beforeAny(ctx, id, mcp.MethodPing, message) + if c == nil { + return + } + for _, hook := range c.OnBeforePing { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterPing(ctx context.Context, id any, message *mcp.PingRequest, result *mcp.EmptyResult) { + c.onSuccess(ctx, id, mcp.MethodPing, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterPing { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeSetLevel(hook OnBeforeSetLevelFunc) { + c.OnBeforeSetLevel = append(c.OnBeforeSetLevel, hook) +} + +func (c *Hooks) AddAfterSetLevel(hook OnAfterSetLevelFunc) { + c.OnAfterSetLevel = append(c.OnAfterSetLevel, hook) +} + +func (c *Hooks) beforeSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest) { + c.beforeAny(ctx, id, mcp.MethodSetLogLevel, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeSetLevel { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterSetLevel(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) { + c.onSuccess(ctx, id, mcp.MethodSetLogLevel, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterSetLevel { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListResources(hook OnBeforeListResourcesFunc) { + c.OnBeforeListResources = append(c.OnBeforeListResources, hook) +} + +func (c *Hooks) AddAfterListResources(hook OnAfterListResourcesFunc) { + c.OnAfterListResources = append(c.OnAfterListResources, hook) +} + +func (c *Hooks) beforeListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListResources { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListResources(ctx context.Context, id any, message *mcp.ListResourcesRequest, result *mcp.ListResourcesResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListResources { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListResourceTemplates(hook OnBeforeListResourceTemplatesFunc) { + c.OnBeforeListResourceTemplates = append(c.OnBeforeListResourceTemplates, hook) +} + +func (c *Hooks) AddAfterListResourceTemplates(hook OnAfterListResourceTemplatesFunc) { + c.OnAfterListResourceTemplates = append(c.OnAfterListResourceTemplates, hook) +} + +func (c *Hooks) beforeListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesTemplatesList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListResourceTemplates { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListResourceTemplates(ctx context.Context, id any, message *mcp.ListResourceTemplatesRequest, result *mcp.ListResourceTemplatesResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesTemplatesList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListResourceTemplates { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeReadResource(hook OnBeforeReadResourceFunc) { + c.OnBeforeReadResource = append(c.OnBeforeReadResource, hook) +} + +func (c *Hooks) AddAfterReadResource(hook OnAfterReadResourceFunc) { + c.OnAfterReadResource = append(c.OnAfterReadResource, hook) +} + +func (c *Hooks) beforeReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest) { + c.beforeAny(ctx, id, mcp.MethodResourcesRead, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeReadResource { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterReadResource(ctx context.Context, id any, message *mcp.ReadResourceRequest, result *mcp.ReadResourceResult) { + c.onSuccess(ctx, id, mcp.MethodResourcesRead, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterReadResource { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListPrompts(hook OnBeforeListPromptsFunc) { + c.OnBeforeListPrompts = append(c.OnBeforeListPrompts, hook) +} + +func (c *Hooks) AddAfterListPrompts(hook OnAfterListPromptsFunc) { + c.OnAfterListPrompts = append(c.OnAfterListPrompts, hook) +} + +func (c *Hooks) beforeListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest) { + c.beforeAny(ctx, id, mcp.MethodPromptsList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListPrompts { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListPrompts(ctx context.Context, id any, message *mcp.ListPromptsRequest, result *mcp.ListPromptsResult) { + c.onSuccess(ctx, id, mcp.MethodPromptsList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListPrompts { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeGetPrompt(hook OnBeforeGetPromptFunc) { + c.OnBeforeGetPrompt = append(c.OnBeforeGetPrompt, hook) +} + +func (c *Hooks) AddAfterGetPrompt(hook OnAfterGetPromptFunc) { + c.OnAfterGetPrompt = append(c.OnAfterGetPrompt, hook) +} + +func (c *Hooks) beforeGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest) { + c.beforeAny(ctx, id, mcp.MethodPromptsGet, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeGetPrompt { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterGetPrompt(ctx context.Context, id any, message *mcp.GetPromptRequest, result *mcp.GetPromptResult) { + c.onSuccess(ctx, id, mcp.MethodPromptsGet, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterGetPrompt { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeListTools(hook OnBeforeListToolsFunc) { + c.OnBeforeListTools = append(c.OnBeforeListTools, hook) +} + +func (c *Hooks) AddAfterListTools(hook OnAfterListToolsFunc) { + c.OnAfterListTools = append(c.OnAfterListTools, hook) +} + +func (c *Hooks) beforeListTools(ctx context.Context, id any, message *mcp.ListToolsRequest) { + c.beforeAny(ctx, id, mcp.MethodToolsList, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeListTools { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterListTools(ctx context.Context, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) { + c.onSuccess(ctx, id, mcp.MethodToolsList, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterListTools { + hook(ctx, id, message, result) + } +} +func (c *Hooks) AddBeforeCallTool(hook OnBeforeCallToolFunc) { + c.OnBeforeCallTool = append(c.OnBeforeCallTool, hook) +} + +func (c *Hooks) AddAfterCallTool(hook OnAfterCallToolFunc) { + c.OnAfterCallTool = append(c.OnAfterCallTool, hook) +} + +func (c *Hooks) beforeCallTool(ctx context.Context, id any, message *mcp.CallToolRequest) { + c.beforeAny(ctx, id, mcp.MethodToolsCall, message) + if c == nil { + return + } + for _, hook := range c.OnBeforeCallTool { + hook(ctx, id, message) + } +} + +func (c *Hooks) afterCallTool(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { + c.onSuccess(ctx, id, mcp.MethodToolsCall, message, result) + if c == nil { + return + } + for _, hook := range c.OnAfterCallTool { + hook(ctx, id, message, result) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go b/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go new file mode 100644 index 0000000000..4f5ad53d0d --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/http_transport_options.go @@ -0,0 +1,11 @@ +package server + +import ( + "context" + "net/http" +) + +// HTTPContextFunc is a function that takes an existing context and the current +// request and returns a potentially modified context based on the request +// content. This can be used to inject context values from headers, for example. +type HTTPContextFunc func(ctx context.Context, r *http.Request) context.Context diff --git a/vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go b/vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go new file mode 100644 index 0000000000..daaf28a5cc --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/inprocess_session.go @@ -0,0 +1,115 @@ +package server + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SamplingHandler defines the interface for handling sampling requests from servers. +type SamplingHandler interface { + CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} + +type InProcessSession struct { + sessionID string + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value + clientCapabilities atomic.Value + samplingHandler SamplingHandler + mu sync.RWMutex +} + +func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InProcessSession { + return &InProcessSession{ + sessionID: sessionID, + notifications: make(chan mcp.JSONRPCNotification, 100), + samplingHandler: samplingHandler, + } +} + +func (s *InProcessSession) SessionID() string { + return s.sessionID +} + +func (s *InProcessSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notifications +} + +func (s *InProcessSession) Initialize() { + s.loggingLevel.Store(mcp.LoggingLevelError) + s.initialized.Store(true) +} + +func (s *InProcessSession) Initialized() bool { + return s.initialized.Load() +} + +func (s *InProcessSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *InProcessSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *InProcessSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *InProcessSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +func (s *InProcessSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *InProcessSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +func (s *InProcessSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + handler := s.samplingHandler + s.mu.RUnlock() + + if handler == nil { + return nil, fmt.Errorf("no sampling handler available") + } + + return handler.CreateMessage(ctx, request) +} + +// GenerateInProcessSessionID generates a unique session ID for inprocess clients +func GenerateInProcessSessionID() string { + return fmt.Sprintf("inprocess-%d", time.Now().UnixNano()) +} + +// Ensure interface compliance +var ( + _ ClientSession = (*InProcessSession)(nil) + _ SessionWithLogging = (*InProcessSession)(nil) + _ SessionWithClientInfo = (*InProcessSession)(nil) + _ SessionWithSampling = (*InProcessSession)(nil) +) diff --git a/vendor/github.com/mark3labs/mcp-go/server/request_handler.go b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go new file mode 100644 index 0000000000..b9175dc4e2 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/request_handler.go @@ -0,0 +1,339 @@ +// Code generated by `go generate`. DO NOT EDIT. +// source: server/internal/gen/request_handler.go.tmpl +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "github.com/mark3labs/mcp-go/mcp" +) + +// HandleMessage processes an incoming JSON-RPC message and returns an appropriate response +func (s *MCPServer) HandleMessage( + ctx context.Context, + message json.RawMessage, +) mcp.JSONRPCMessage { + // Add server to context + ctx = context.WithValue(ctx, serverKey{}, s) + var err *requestError + + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + Method mcp.MCPMethod `json:"method"` + ID any `json:"id,omitempty"` + Result any `json:"result,omitempty"` + } + + if err := json.Unmarshal(message, &baseMessage); err != nil { + return createErrorResponse( + nil, + mcp.PARSE_ERROR, + "Failed to parse message", + ) + } + + // Check for valid JSONRPC version + if baseMessage.JSONRPC != mcp.JSONRPC_VERSION { + return createErrorResponse( + baseMessage.ID, + mcp.INVALID_REQUEST, + "Invalid JSON-RPC version", + ) + } + + if baseMessage.ID == nil { + var notification mcp.JSONRPCNotification + if err := json.Unmarshal(message, ¬ification); err != nil { + return createErrorResponse( + nil, + mcp.PARSE_ERROR, + "Failed to parse notification", + ) + } + s.handleNotification(ctx, notification) + return nil // Return nil for notifications + } + + if baseMessage.Result != nil { + // this is a response to a request sent by the server (e.g. from a ping + // sent due to WithKeepAlive option) + return nil + } + + handleErr := s.hooks.onRequestInitialization(ctx, baseMessage.ID, message) + if handleErr != nil { + return createErrorResponse( + baseMessage.ID, + mcp.INVALID_REQUEST, + handleErr.Error(), + ) + } + + // Get request header from ctx + h := ctx.Value(requestHeader) + headers, ok := h.(http.Header) + + if headers == nil || !ok { + headers = make(http.Header) + } + + switch baseMessage.Method { + case mcp.MethodInitialize: + var request mcp.InitializeRequest + var result *mcp.InitializeResult + if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) + result, err = s.handleInitialize(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterInitialize(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPing: + var request mcp.PingRequest + var result *mcp.EmptyResult + if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforePing(ctx, baseMessage.ID, &request) + result, err = s.handlePing(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterPing(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodSetLogLevel: + var request mcp.SetLevelRequest + var result *mcp.EmptyResult + if s.capabilities.logging == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("logging %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeSetLevel(ctx, baseMessage.ID, &request) + result, err = s.handleSetLevel(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterSetLevel(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesList: + var request mcp.ListResourcesRequest + var result *mcp.ListResourcesResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeListResources(ctx, baseMessage.ID, &request) + result, err = s.handleListResources(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListResources(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesTemplatesList: + var request mcp.ListResourceTemplatesRequest + var result *mcp.ListResourceTemplatesResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) + result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListResourceTemplates(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodResourcesRead: + var request mcp.ReadResourceRequest + var result *mcp.ReadResourceResult + if s.capabilities.resources == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("resources %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) + result, err = s.handleReadResource(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterReadResource(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPromptsList: + var request mcp.ListPromptsRequest + var result *mcp.ListPromptsResult + if s.capabilities.prompts == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("prompts %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) + result, err = s.handleListPrompts(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListPrompts(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodPromptsGet: + var request mcp.GetPromptRequest + var result *mcp.GetPromptResult + if s.capabilities.prompts == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("prompts %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) + result, err = s.handleGetPrompt(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterGetPrompt(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodToolsList: + var request mcp.ListToolsRequest + var result *mcp.ListToolsResult + if s.capabilities.tools == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("tools %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeListTools(ctx, baseMessage.ID, &request) + result, err = s.handleListTools(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterListTools(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + case mcp.MethodToolsCall: + var request mcp.CallToolRequest + var result *mcp.CallToolResult + if s.capabilities.tools == nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.METHOD_NOT_FOUND, + err: fmt.Errorf("tools %w", ErrUnsupported), + } + } else if unmarshalErr := json.Unmarshal(message, &request); unmarshalErr != nil { + err = &requestError{ + id: baseMessage.ID, + code: mcp.INVALID_REQUEST, + err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, + } + } else { + request.Header = headers + s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) + result, err = s.handleToolCall(ctx, baseMessage.ID, request) + } + if err != nil { + s.hooks.onError(ctx, baseMessage.ID, baseMessage.Method, &request, err) + return err.ToJSONRPCError() + } + s.hooks.afterCallTool(ctx, baseMessage.ID, &request, result) + return createResponse(baseMessage.ID, *result) + default: + return createErrorResponse( + baseMessage.ID, + mcp.METHOD_NOT_FOUND, + fmt.Sprintf("Method %s not found", baseMessage.Method), + ) + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/sampling.go b/vendor/github.com/mark3labs/mcp-go/server/sampling.go new file mode 100644 index 0000000000..2118db155a --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sampling.go @@ -0,0 +1,61 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// EnableSampling enables sampling capabilities for the server. +// This allows the server to send sampling requests to clients that support it. +func (s *MCPServer) EnableSampling() { + s.capabilitiesMu.Lock() + defer s.capabilitiesMu.Unlock() + + enabled := true + s.capabilities.sampling = &enabled +} + +// RequestSampling sends a sampling request to the client. +// The client must have declared sampling capability during initialization. +func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports sampling requests + if samplingSession, ok := session.(SessionWithSampling); ok { + return samplingSession.RequestSampling(ctx, request) + } + + // Check for inprocess sampling handler in context + if handler := InProcessSamplingHandlerFromContext(ctx); handler != nil { + return handler.CreateMessage(ctx, request) + } + + return nil, fmt.Errorf("session does not support sampling") +} + +// SessionWithSampling extends ClientSession to support sampling requests. +type SessionWithSampling interface { + ClientSession + RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} + +// inProcessSamplingHandlerKey is the context key for storing inprocess sampling handler +type inProcessSamplingHandlerKey struct{} + +// WithInProcessSamplingHandler adds a sampling handler to the context for inprocess clients +func WithInProcessSamplingHandler(ctx context.Context, handler SamplingHandler) context.Context { + return context.WithValue(ctx, inProcessSamplingHandlerKey{}, handler) +} + +// InProcessSamplingHandlerFromContext retrieves the inprocess sampling handler from context +func InProcessSamplingHandlerFromContext(ctx context.Context) SamplingHandler { + if handler, ok := ctx.Value(inProcessSamplingHandlerKey{}).(SamplingHandler); ok { + return handler + } + return nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/server.go b/vendor/github.com/mark3labs/mcp-go/server/server.go new file mode 100644 index 0000000000..6883572807 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/server.go @@ -0,0 +1,1201 @@ +// Package server provides MCP (Model Context Protocol) server implementations. +package server + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "slices" + "sort" + "sync" + + "github.com/mark3labs/mcp-go/mcp" +) + +// resourceEntry holds both a resource and its handler +type resourceEntry struct { + resource mcp.Resource + handler ResourceHandlerFunc +} + +// resourceTemplateEntry holds both a template and its handler +type resourceTemplateEntry struct { + template mcp.ResourceTemplate + handler ResourceTemplateHandlerFunc +} + +// ServerOption is a function that configures an MCPServer. +type ServerOption func(*MCPServer) + +// ResourceHandlerFunc is a function that returns resource contents. +type ResourceHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) + +// ResourceTemplateHandlerFunc is a function that returns a resource template. +type ResourceTemplateHandlerFunc func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) + +// PromptHandlerFunc handles prompt requests with given arguments. +type PromptHandlerFunc func(ctx context.Context, request mcp.GetPromptRequest) (*mcp.GetPromptResult, error) + +// ToolHandlerFunc handles tool calls with given arguments. +type ToolHandlerFunc func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) + +// ToolHandlerMiddleware is a middleware function that wraps a ToolHandlerFunc. +type ToolHandlerMiddleware func(ToolHandlerFunc) ToolHandlerFunc + +// ResourceHandlerMiddleware is a middleware function that wraps a ResourceHandlerFunc. +type ResourceHandlerMiddleware func(ResourceHandlerFunc) ResourceHandlerFunc + +// ToolFilterFunc is a function that filters tools based on context, typically using session information. +type ToolFilterFunc func(ctx context.Context, tools []mcp.Tool) []mcp.Tool + +// ServerTool combines a Tool with its ToolHandlerFunc. +type ServerTool struct { + Tool mcp.Tool + Handler ToolHandlerFunc +} + +// ServerPrompt combines a Prompt with its handler function. +type ServerPrompt struct { + Prompt mcp.Prompt + Handler PromptHandlerFunc +} + +// ServerResource combines a Resource with its handler function. +type ServerResource struct { + Resource mcp.Resource + Handler ResourceHandlerFunc +} + +// ServerResourceTemplate combines a ResourceTemplate with its handler function. +type ServerResourceTemplate struct { + Template mcp.ResourceTemplate + Handler ResourceTemplateHandlerFunc +} + +// serverKey is the context key for storing the server instance +type serverKey struct{} + +// ServerFromContext retrieves the MCPServer instance from a context +func ServerFromContext(ctx context.Context) *MCPServer { + if srv, ok := ctx.Value(serverKey{}).(*MCPServer); ok { + return srv + } + return nil +} + +// UnparsableMessageError is attached to the RequestError when json.Unmarshal +// fails on the request. +type UnparsableMessageError struct { + message json.RawMessage + method mcp.MCPMethod + err error +} + +func (e *UnparsableMessageError) Error() string { + return fmt.Sprintf("unparsable %s request: %s", e.method, e.err) +} + +func (e *UnparsableMessageError) Unwrap() error { + return e.err +} + +func (e *UnparsableMessageError) GetMessage() json.RawMessage { + return e.message +} + +func (e *UnparsableMessageError) GetMethod() mcp.MCPMethod { + return e.method +} + +// RequestError is an error that can be converted to a JSON-RPC error. +// Implements Unwrap() to allow inspecting the error chain. +type requestError struct { + id any + code int + err error +} + +func (e *requestError) Error() string { + return fmt.Sprintf("request error: %s", e.err) +} + +func (e *requestError) ToJSONRPCError() mcp.JSONRPCError { + return mcp.JSONRPCError{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(e.id), + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` + }{ + Code: e.code, + Message: e.err.Error(), + }, + } +} + +func (e *requestError) Unwrap() error { + return e.err +} + +// NotificationHandlerFunc handles incoming notifications. +type NotificationHandlerFunc func(ctx context.Context, notification mcp.JSONRPCNotification) + +// MCPServer implements a Model Context Protocol server that can handle various types of requests +// including resources, prompts, and tools. +type MCPServer struct { + // Separate mutexes for different resource types + resourcesMu sync.RWMutex + promptsMu sync.RWMutex + toolsMu sync.RWMutex + middlewareMu sync.RWMutex + notificationHandlersMu sync.RWMutex + capabilitiesMu sync.RWMutex + toolFiltersMu sync.RWMutex + + name string + version string + instructions string + resources map[string]resourceEntry + resourceTemplates map[string]resourceTemplateEntry + prompts map[string]mcp.Prompt + promptHandlers map[string]PromptHandlerFunc + tools map[string]ServerTool + toolHandlerMiddlewares []ToolHandlerMiddleware + resourceHandlerMiddlewares []ResourceHandlerMiddleware + toolFilters []ToolFilterFunc + notificationHandlers map[string]NotificationHandlerFunc + capabilities serverCapabilities + paginationLimit *int + sessions sync.Map + hooks *Hooks +} + +// WithPaginationLimit sets the pagination limit for the server. +func WithPaginationLimit(limit int) ServerOption { + return func(s *MCPServer) { + s.paginationLimit = &limit + } +} + +// serverCapabilities defines the supported features of the MCP server +type serverCapabilities struct { + tools *toolCapabilities + resources *resourceCapabilities + prompts *promptCapabilities + logging *bool + sampling *bool +} + +// resourceCapabilities defines the supported resource-related features +type resourceCapabilities struct { + subscribe bool + listChanged bool +} + +// promptCapabilities defines the supported prompt-related features +type promptCapabilities struct { + listChanged bool +} + +// toolCapabilities defines the supported tool-related features +type toolCapabilities struct { + listChanged bool +} + +// WithResourceCapabilities configures resource-related server capabilities +func WithResourceCapabilities(subscribe, listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.resources = &resourceCapabilities{ + subscribe: subscribe, + listChanged: listChanged, + } + } +} + +// WithToolHandlerMiddleware allows adding a middleware for the +// tool handler call chain. +func WithToolHandlerMiddleware( + toolHandlerMiddleware ToolHandlerMiddleware, +) ServerOption { + return func(s *MCPServer) { + s.middlewareMu.Lock() + s.toolHandlerMiddlewares = append(s.toolHandlerMiddlewares, toolHandlerMiddleware) + s.middlewareMu.Unlock() + } +} + +// WithResourceHandlerMiddleware allows adding a middleware for the +// resource handler call chain. +func WithResourceHandlerMiddleware( + resourceHandlerMiddleware ResourceHandlerMiddleware, +) ServerOption { + return func(s *MCPServer) { + s.middlewareMu.Lock() + s.resourceHandlerMiddlewares = append(s.resourceHandlerMiddlewares, resourceHandlerMiddleware) + s.middlewareMu.Unlock() + } +} + +// WithResourceRecovery adds a middleware that recovers from panics in resource handlers. +func WithResourceRecovery() ServerOption { + return WithResourceHandlerMiddleware(func(next ResourceHandlerFunc) ResourceHandlerFunc { + return func(ctx context.Context, request mcp.ReadResourceRequest) (result []mcp.ResourceContents, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf( + "panic recovered in %s resource handler: %v", + request.Params.URI, + r, + ) + } + }() + return next(ctx, request) + } + }) +} + +// WithToolFilter adds a filter function that will be applied to tools before they are returned in list_tools +func WithToolFilter( + toolFilter ToolFilterFunc, +) ServerOption { + return func(s *MCPServer) { + s.toolFiltersMu.Lock() + s.toolFilters = append(s.toolFilters, toolFilter) + s.toolFiltersMu.Unlock() + } +} + +// WithRecovery adds a middleware that recovers from panics in tool handlers. +func WithRecovery() ServerOption { + return WithToolHandlerMiddleware(func(next ToolHandlerFunc) ToolHandlerFunc { + return func(ctx context.Context, request mcp.CallToolRequest) (result *mcp.CallToolResult, err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf( + "panic recovered in %s tool handler: %v", + request.Params.Name, + r, + ) + } + }() + return next(ctx, request) + } + }) +} + +// WithHooks allows adding hooks that will be called before or after +// either [all] requests or before / after specific request methods, or else +// prior to returning an error to the client. +func WithHooks(hooks *Hooks) ServerOption { + return func(s *MCPServer) { + s.hooks = hooks + } +} + +// WithPromptCapabilities configures prompt-related server capabilities +func WithPromptCapabilities(listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.prompts = &promptCapabilities{ + listChanged: listChanged, + } + } +} + +// WithToolCapabilities configures tool-related server capabilities +func WithToolCapabilities(listChanged bool) ServerOption { + return func(s *MCPServer) { + // Always create a non-nil capability object + s.capabilities.tools = &toolCapabilities{ + listChanged: listChanged, + } + } +} + +// WithLogging enables logging capabilities for the server +func WithLogging() ServerOption { + return func(s *MCPServer) { + s.capabilities.logging = mcp.ToBoolPtr(true) + } +} + +// WithInstructions sets the server instructions for the client returned in the initialize response +func WithInstructions(instructions string) ServerOption { + return func(s *MCPServer) { + s.instructions = instructions + } +} + +// NewMCPServer creates a new MCP server instance with the given name, version and options +func NewMCPServer( + name, version string, + opts ...ServerOption, +) *MCPServer { + s := &MCPServer{ + resources: make(map[string]resourceEntry), + resourceTemplates: make(map[string]resourceTemplateEntry), + prompts: make(map[string]mcp.Prompt), + promptHandlers: make(map[string]PromptHandlerFunc), + tools: make(map[string]ServerTool), + toolHandlerMiddlewares: make([]ToolHandlerMiddleware, 0), + resourceHandlerMiddlewares: make([]ResourceHandlerMiddleware, 0), + name: name, + version: version, + notificationHandlers: make(map[string]NotificationHandlerFunc), + capabilities: serverCapabilities{ + tools: nil, + resources: nil, + prompts: nil, + logging: nil, + }, + } + + for _, opt := range opts { + opt(s) + } + + return s +} + +// GenerateInProcessSessionID generates a unique session ID for inprocess clients +func (s *MCPServer) GenerateInProcessSessionID() string { + return GenerateInProcessSessionID() +} + +// AddResources registers multiple resources at once +func (s *MCPServer) AddResources(resources ...ServerResource) { + s.implicitlyRegisterResourceCapabilities() + + s.resourcesMu.Lock() + for _, entry := range resources { + s.resources[entry.Resource.URI] = resourceEntry{ + resource: entry.Resource, + handler: entry.Handler, + } + } + s.resourcesMu.Unlock() + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + +// SetResources replaces all existing resources with the provided list +func (s *MCPServer) SetResources(resources ...ServerResource) { + s.resourcesMu.Lock() + s.resources = make(map[string]resourceEntry, len(resources)) + s.resourcesMu.Unlock() + s.AddResources(resources...) +} + +// AddResource registers a new resource and its handler +func (s *MCPServer) AddResource( + resource mcp.Resource, + handler ResourceHandlerFunc, +) { + s.AddResources(ServerResource{Resource: resource, Handler: handler}) +} + +// DeleteResources removes resources from the server +func (s *MCPServer) DeleteResources(uris ...string) { + s.resourcesMu.Lock() + var exists bool + for _, uri := range uris { + if _, ok := s.resources[uri]; ok { + delete(s.resources, uri) + exists = true + } + } + s.resourcesMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource + if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + +// RemoveResource removes a resource from the server +func (s *MCPServer) RemoveResource(uri string) { + s.resourcesMu.Lock() + _, exists := s.resources[uri] + if exists { + delete(s.resources, uri) + } + s.resourcesMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled and we actually remove a resource + if exists && s.capabilities.resources != nil && s.capabilities.resources.listChanged { + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + +// AddResourceTemplates registers multiple resource templates at once +func (s *MCPServer) AddResourceTemplates(resourceTemplates ...ServerResourceTemplate) { + s.implicitlyRegisterResourceCapabilities() + + s.resourcesMu.Lock() + for _, entry := range resourceTemplates { + s.resourceTemplates[entry.Template.URITemplate.Raw()] = resourceTemplateEntry{ + template: entry.Template, + handler: entry.Handler, + } + } + s.resourcesMu.Unlock() + + // When the list of available resources changes, servers that declared the listChanged capability SHOULD send a notification + if s.capabilities.resources.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationResourcesListChanged, nil) + } +} + +// SetResourceTemplates replaces all existing resource templates with the provided list +func (s *MCPServer) SetResourceTemplates(templates ...ServerResourceTemplate) { + s.resourcesMu.Lock() + s.resourceTemplates = make(map[string]resourceTemplateEntry, len(templates)) + s.resourcesMu.Unlock() + s.AddResourceTemplates(templates...) +} + +// AddResourceTemplate registers a new resource template and its handler +func (s *MCPServer) AddResourceTemplate( + template mcp.ResourceTemplate, + handler ResourceTemplateHandlerFunc, +) { + s.AddResourceTemplates(ServerResourceTemplate{Template: template, Handler: handler}) +} + +// AddPrompts registers multiple prompts at once +func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) { + s.implicitlyRegisterPromptCapabilities() + + s.promptsMu.Lock() + for _, entry := range prompts { + s.prompts[entry.Prompt.Name] = entry.Prompt + s.promptHandlers[entry.Prompt.Name] = entry.Handler + } + s.promptsMu.Unlock() + + // When the list of available prompts changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.prompts.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + } +} + +// AddPrompt registers a new prompt handler with the given name +func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { + s.AddPrompts(ServerPrompt{Prompt: prompt, Handler: handler}) +} + +// SetPrompts replaces all existing prompts with the provided list +func (s *MCPServer) SetPrompts(prompts ...ServerPrompt) { + s.promptsMu.Lock() + s.prompts = make(map[string]mcp.Prompt, len(prompts)) + s.promptHandlers = make(map[string]PromptHandlerFunc, len(prompts)) + s.promptsMu.Unlock() + s.AddPrompts(prompts...) +} + +// DeletePrompts removes prompts from the server +func (s *MCPServer) DeletePrompts(names ...string) { + s.promptsMu.Lock() + var exists bool + for _, name := range names { + if _, ok := s.prompts[name]; ok { + delete(s.prompts, name) + delete(s.promptHandlers, name) + exists = true + } + } + s.promptsMu.Unlock() + + // Send notification to all initialized sessions if listChanged capability is enabled, and we actually remove a prompt + if exists && s.capabilities.prompts != nil && s.capabilities.prompts.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationPromptsListChanged, nil) + } +} + +// AddTool registers a new tool and its handler +func (s *MCPServer) AddTool(tool mcp.Tool, handler ToolHandlerFunc) { + s.AddTools(ServerTool{Tool: tool, Handler: handler}) +} + +// Register tool capabilities due to a tool being added. Default to +// listChanged: true, but don't change the value if we've already explicitly +// registered tools.listChanged false. +func (s *MCPServer) implicitlyRegisterToolCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.tools != nil }, + func() { s.capabilities.tools = &toolCapabilities{listChanged: true} }, + ) +} + +func (s *MCPServer) implicitlyRegisterResourceCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.resources != nil }, + func() { s.capabilities.resources = &resourceCapabilities{} }, + ) +} + +func (s *MCPServer) implicitlyRegisterPromptCapabilities() { + s.implicitlyRegisterCapabilities( + func() bool { return s.capabilities.prompts != nil }, + func() { s.capabilities.prompts = &promptCapabilities{} }, + ) +} + +func (s *MCPServer) implicitlyRegisterCapabilities(check func() bool, register func()) { + s.capabilitiesMu.RLock() + if check() { + s.capabilitiesMu.RUnlock() + return + } + s.capabilitiesMu.RUnlock() + + s.capabilitiesMu.Lock() + if !check() { + register() + } + s.capabilitiesMu.Unlock() +} + +// AddTools registers multiple tools at once +func (s *MCPServer) AddTools(tools ...ServerTool) { + s.implicitlyRegisterToolCapabilities() + + s.toolsMu.Lock() + for _, entry := range tools { + s.tools[entry.Tool.Name] = entry + } + s.toolsMu.Unlock() + + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } +} + +// SetTools replaces all existing tools with the provided list +func (s *MCPServer) SetTools(tools ...ServerTool) { + s.toolsMu.Lock() + s.tools = make(map[string]ServerTool, len(tools)) + s.toolsMu.Unlock() + s.AddTools(tools...) +} + +// DeleteTools removes tools from the server +func (s *MCPServer) DeleteTools(names ...string) { + s.toolsMu.Lock() + var exists bool + for _, name := range names { + if _, ok := s.tools[name]; ok { + delete(s.tools, name) + exists = true + } + } + s.toolsMu.Unlock() + + // When the list of available tools changes, servers that declared the listChanged capability SHOULD send a notification. + if exists && s.capabilities.tools != nil && s.capabilities.tools.listChanged { + // Send notification to all initialized sessions + s.SendNotificationToAllClients(mcp.MethodNotificationToolsListChanged, nil) + } +} + +// AddNotificationHandler registers a new handler for incoming notifications +func (s *MCPServer) AddNotificationHandler( + method string, + handler NotificationHandlerFunc, +) { + s.notificationHandlersMu.Lock() + defer s.notificationHandlersMu.Unlock() + s.notificationHandlers[method] = handler +} + +func (s *MCPServer) handleInitialize( + ctx context.Context, + _ any, + request mcp.InitializeRequest, +) (*mcp.InitializeResult, *requestError) { + capabilities := mcp.ServerCapabilities{} + + // Only add resource capabilities if they're configured + if s.capabilities.resources != nil { + capabilities.Resources = &struct { + Subscribe bool `json:"subscribe,omitempty"` + ListChanged bool `json:"listChanged,omitempty"` + }{ + Subscribe: s.capabilities.resources.subscribe, + ListChanged: s.capabilities.resources.listChanged, + } + } + + // Only add prompt capabilities if they're configured + if s.capabilities.prompts != nil { + capabilities.Prompts = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: s.capabilities.prompts.listChanged, + } + } + + // Only add tool capabilities if they're configured + if s.capabilities.tools != nil { + capabilities.Tools = &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: s.capabilities.tools.listChanged, + } + } + + if s.capabilities.logging != nil && *s.capabilities.logging { + capabilities.Logging = &struct{}{} + } + + if s.capabilities.sampling != nil && *s.capabilities.sampling { + capabilities.Sampling = &struct{}{} + } + + result := mcp.InitializeResult{ + ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), + ServerInfo: mcp.Implementation{ + Name: s.name, + Version: s.version, + }, + Capabilities: capabilities, + Instructions: s.instructions, + } + + if session := ClientSessionFromContext(ctx); session != nil { + session.Initialize() + + // Store client info if the session supports it + if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok { + sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo) + sessionWithClientInfo.SetClientCapabilities(request.Params.Capabilities) + } + } + + return &result, nil +} + +func (s *MCPServer) protocolVersion(clientVersion string) string { + // For backwards compatibility, if the server does not receive an MCP-Protocol-Version header, + // and has no other way to identify the version - for example, by relying on the protocol version negotiated + // during initialization - the server SHOULD assume protocol version 2025-03-26 + // https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + if len(clientVersion) == 0 { + clientVersion = "2025-03-26" + } + + if slices.Contains(mcp.ValidProtocolVersions, clientVersion) { + return clientVersion + } + + return mcp.LATEST_PROTOCOL_VERSION +} + +func (s *MCPServer) handlePing( + _ context.Context, + _ any, + _ mcp.PingRequest, +) (*mcp.EmptyResult, *requestError) { + return &mcp.EmptyResult{}, nil +} + +func (s *MCPServer) handleSetLevel( + ctx context.Context, + id any, + request mcp.SetLevelRequest, +) (*mcp.EmptyResult, *requestError) { + clientSession := ClientSessionFromContext(ctx) + if clientSession == nil || !clientSession.Initialized() { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: ErrSessionNotInitialized, + } + } + + sessionLogging, ok := clientSession.(SessionWithLogging) + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: ErrSessionDoesNotSupportLogging, + } + } + + level := request.Params.Level + // Validate logging level + switch level { + case mcp.LoggingLevelDebug, mcp.LoggingLevelInfo, mcp.LoggingLevelNotice, + mcp.LoggingLevelWarning, mcp.LoggingLevelError, mcp.LoggingLevelCritical, + mcp.LoggingLevelAlert, mcp.LoggingLevelEmergency: + // Valid level + default: + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("invalid logging level '%s'", level), + } + } + + sessionLogging.SetLogLevel(level) + + return &mcp.EmptyResult{}, nil +} + +func listByPagination[T mcp.Named]( + _ context.Context, + s *MCPServer, + cursor mcp.Cursor, + allElements []T, +) ([]T, mcp.Cursor, error) { + startPos := 0 + if cursor != "" { + c, err := base64.StdEncoding.DecodeString(string(cursor)) + if err != nil { + return nil, "", err + } + cString := string(c) + startPos = sort.Search(len(allElements), func(i int) bool { + return allElements[i].GetName() > cString + }) + } + endPos := len(allElements) + if s.paginationLimit != nil { + if len(allElements) > startPos+*s.paginationLimit { + endPos = startPos + *s.paginationLimit + } + } + elementsToReturn := allElements[startPos:endPos] + // set the next cursor + nextCursor := func() mcp.Cursor { + if s.paginationLimit != nil && len(elementsToReturn) >= *s.paginationLimit { + nc := elementsToReturn[len(elementsToReturn)-1].GetName() + toString := base64.StdEncoding.EncodeToString([]byte(nc)) + return mcp.Cursor(toString) + } + return "" + }() + return elementsToReturn, nextCursor, nil +} + +func (s *MCPServer) handleListResources( + ctx context.Context, + id any, + request mcp.ListResourcesRequest, +) (*mcp.ListResourcesResult, *requestError) { + s.resourcesMu.RLock() + resources := make([]mcp.Resource, 0, len(s.resources)) + for _, entry := range s.resources { + resources = append(resources, entry.resource) + } + s.resourcesMu.RUnlock() + + // Sort the resources by name + sort.Slice(resources, func(i, j int) bool { + return resources[i].Name < resources[j].Name + }) + resourcesToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + resources, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + result := mcp.ListResourcesResult{ + Resources: resourcesToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, + } + return &result, nil +} + +func (s *MCPServer) handleListResourceTemplates( + ctx context.Context, + id any, + request mcp.ListResourceTemplatesRequest, +) (*mcp.ListResourceTemplatesResult, *requestError) { + s.resourcesMu.RLock() + templates := make([]mcp.ResourceTemplate, 0, len(s.resourceTemplates)) + for _, entry := range s.resourceTemplates { + templates = append(templates, entry.template) + } + s.resourcesMu.RUnlock() + sort.Slice(templates, func(i, j int) bool { + return templates[i].Name < templates[j].Name + }) + templatesToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + templates, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + result := mcp.ListResourceTemplatesResult{ + ResourceTemplates: templatesToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, + } + return &result, nil +} + +func (s *MCPServer) handleReadResource( + ctx context.Context, + id any, + request mcp.ReadResourceRequest, +) (*mcp.ReadResourceResult, *requestError) { + s.resourcesMu.RLock() + // First try direct resource handlers + if entry, ok := s.resources[request.Params.URI]; ok { + handler := entry.handler + s.resourcesMu.RUnlock() + + finalHandler := handler + s.middlewareMu.RLock() + mw := s.resourceHandlerMiddlewares + // Apply middlewares in reverse order + for i := len(mw) - 1; i >= 0; i-- { + finalHandler = mw[i](finalHandler) + } + s.middlewareMu.RUnlock() + + contents, err := finalHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return &mcp.ReadResourceResult{Contents: contents}, nil + } + + // If no direct handler found, try matching against templates + var matchedHandler ResourceTemplateHandlerFunc + var matched bool + for _, entry := range s.resourceTemplates { + template := entry.template + if matchesTemplate(request.Params.URI, template.URITemplate) { + matchedHandler = entry.handler + matched = true + matchedVars := template.URITemplate.Match(request.Params.URI) + // Convert matched variables to a map + request.Params.Arguments = make(map[string]any, len(matchedVars)) + for name, value := range matchedVars { + request.Params.Arguments[name] = value.V + } + break + } + } + s.resourcesMu.RUnlock() + + if matched { + contents, err := matchedHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + return &mcp.ReadResourceResult{Contents: contents}, nil + } + + return nil, &requestError{ + id: id, + code: mcp.RESOURCE_NOT_FOUND, + err: fmt.Errorf( + "handler not found for resource URI '%s': %w", + request.Params.URI, + ErrResourceNotFound, + ), + } +} + +// matchesTemplate checks if a URI matches a URI template pattern +func matchesTemplate(uri string, template *mcp.URITemplate) bool { + return template.Regexp().MatchString(uri) +} + +func (s *MCPServer) handleListPrompts( + ctx context.Context, + id any, + request mcp.ListPromptsRequest, +) (*mcp.ListPromptsResult, *requestError) { + s.promptsMu.RLock() + prompts := make([]mcp.Prompt, 0, len(s.prompts)) + for _, prompt := range s.prompts { + prompts = append(prompts, prompt) + } + s.promptsMu.RUnlock() + + // sort prompts by name + sort.Slice(prompts, func(i, j int) bool { + return prompts[i].Name < prompts[j].Name + }) + promptsToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + prompts, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + result := mcp.ListPromptsResult{ + Prompts: promptsToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, + } + return &result, nil +} + +func (s *MCPServer) handleGetPrompt( + ctx context.Context, + id any, + request mcp.GetPromptRequest, +) (*mcp.GetPromptResult, *requestError) { + s.promptsMu.RLock() + handler, ok := s.promptHandlers[request.Params.Name] + s.promptsMu.RUnlock() + + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("prompt '%s' not found: %w", request.Params.Name, ErrPromptNotFound), + } + } + + result, err := handler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + + return result, nil +} + +func (s *MCPServer) handleListTools( + ctx context.Context, + id any, + request mcp.ListToolsRequest, +) (*mcp.ListToolsResult, *requestError) { + // Get the base tools from the server + s.toolsMu.RLock() + tools := make([]mcp.Tool, 0, len(s.tools)) + + // Get all tool names for consistent ordering + toolNames := make([]string, 0, len(s.tools)) + for name := range s.tools { + toolNames = append(toolNames, name) + } + + // Sort the tool names for consistent ordering + sort.Strings(toolNames) + + // Add tools in sorted order + for _, name := range toolNames { + tools = append(tools, s.tools[name].Tool) + } + s.toolsMu.RUnlock() + + // Check if there are session-specific tools + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, ok := session.(SessionWithTools); ok { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + // Override or add session-specific tools + // We need to create a map first to merge the tools properly + toolMap := make(map[string]mcp.Tool) + + // Add global tools first + for _, tool := range tools { + toolMap[tool.Name] = tool + } + + // Then override with session-specific tools + for name, serverTool := range sessionTools { + toolMap[name] = serverTool.Tool + } + + // Convert back to slice + tools = make([]mcp.Tool, 0, len(toolMap)) + for _, tool := range toolMap { + tools = append(tools, tool) + } + + // Sort again to maintain consistent ordering + sort.Slice(tools, func(i, j int) bool { + return tools[i].Name < tools[j].Name + }) + } + } + } + + // Apply tool filters if any are defined + s.toolFiltersMu.RLock() + if len(s.toolFilters) > 0 { + for _, filter := range s.toolFilters { + tools = filter(ctx, tools) + } + } + s.toolFiltersMu.RUnlock() + + // Apply pagination + toolsToReturn, nextCursor, err := listByPagination( + ctx, + s, + request.Params.Cursor, + tools, + ) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: err, + } + } + + result := mcp.ListToolsResult{ + Tools: toolsToReturn, + PaginatedResult: mcp.PaginatedResult{ + NextCursor: nextCursor, + }, + } + return &result, nil +} + +func (s *MCPServer) handleToolCall( + ctx context.Context, + id any, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, *requestError) { + // First check session-specific tools + var tool ServerTool + var ok bool + + session := ClientSessionFromContext(ctx) + if session != nil { + if sessionWithTools, typeAssertOk := session.(SessionWithTools); typeAssertOk { + if sessionTools := sessionWithTools.GetSessionTools(); sessionTools != nil { + var sessionOk bool + tool, sessionOk = sessionTools[request.Params.Name] + if sessionOk { + ok = true + } + } + } + } + + // If not found in session tools, check global tools + if !ok { + s.toolsMu.RLock() + tool, ok = s.tools[request.Params.Name] + s.toolsMu.RUnlock() + } + + if !ok { + return nil, &requestError{ + id: id, + code: mcp.INVALID_PARAMS, + err: fmt.Errorf("tool '%s' not found: %w", request.Params.Name, ErrToolNotFound), + } + } + + finalHandler := tool.Handler + + s.middlewareMu.RLock() + mw := s.toolHandlerMiddlewares + + // Apply middlewares in reverse order + for i := len(mw) - 1; i >= 0; i-- { + finalHandler = mw[i](finalHandler) + } + s.middlewareMu.RUnlock() + + result, err := finalHandler(ctx, request) + if err != nil { + return nil, &requestError{ + id: id, + code: mcp.INTERNAL_ERROR, + err: err, + } + } + + return result, nil +} + +func (s *MCPServer) handleNotification( + ctx context.Context, + notification mcp.JSONRPCNotification, +) mcp.JSONRPCMessage { + s.notificationHandlersMu.RLock() + handler, ok := s.notificationHandlers[notification.Method] + s.notificationHandlersMu.RUnlock() + + if ok { + handler(ctx, notification) + } + return nil +} + +func createResponse(id any, result any) mcp.JSONRPCMessage { + return mcp.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(id), + Result: result, + } +} + +func createErrorResponse( + id any, + code int, + message string, +) mcp.JSONRPCMessage { + return mcp.JSONRPCError{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(id), + Error: struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` + }{ + Code: code, + Message: message, + }, + } +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/session.go b/vendor/github.com/mark3labs/mcp-go/server/session.go new file mode 100644 index 0000000000..11ee8a2f1c --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/session.go @@ -0,0 +1,444 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// ClientSession represents an active session that can be used by MCPServer to interact with client. +type ClientSession interface { + // Initialize marks session as fully initialized and ready for notifications + Initialize() + // Initialized returns if session is ready to accept notifications + Initialized() bool + // NotificationChannel provides a channel suitable for sending notifications to client. + NotificationChannel() chan<- mcp.JSONRPCNotification + // SessionID is a unique identifier used to track user session. + SessionID() string +} + +// SessionWithLogging is an extension of ClientSession that can receive log message notifications and set log level +type SessionWithLogging interface { + ClientSession + // SetLogLevel sets the minimum log level + SetLogLevel(level mcp.LoggingLevel) + // GetLogLevel retrieves the minimum log level + GetLogLevel() mcp.LoggingLevel +} + +// SessionWithTools is an extension of ClientSession that can store session-specific tool data +type SessionWithTools interface { + ClientSession + // GetSessionTools returns the tools specific to this session, if any + // This method must be thread-safe for concurrent access + GetSessionTools() map[string]ServerTool + // SetSessionTools sets tools specific to this session + // This method must be thread-safe for concurrent access + SetSessionTools(tools map[string]ServerTool) +} + +// SessionWithClientInfo is an extension of ClientSession that can store client info +type SessionWithClientInfo interface { + ClientSession + // GetClientInfo returns the client information for this session + GetClientInfo() mcp.Implementation + // SetClientInfo sets the client information for this session + SetClientInfo(clientInfo mcp.Implementation) + // GetClientCapabilities returns the client capabilities for this session + GetClientCapabilities() mcp.ClientCapabilities + // SetClientCapabilities sets the client capabilities for this session + SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) +} + +// SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations +type SessionWithStreamableHTTPConfig interface { + ClientSession + // UpgradeToSSEWhenReceiveNotification upgrades the client-server communication to SSE stream when the server + // sends notifications to the client + // + // The protocol specification: + // - If the server response contains any JSON-RPC notifications, it MUST either: + // - Return Content-Type: text/event-stream to initiate an SSE stream, OR + // - Return Content-Type: application/json for a single JSON object + // - The client MUST support both response types. + // + // Reference: https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#sending-messages-to-the-server + UpgradeToSSEWhenReceiveNotification() +} + +// clientSessionKey is the context key for storing current client notification channel. +type clientSessionKey struct{} + +// ClientSessionFromContext retrieves current client notification context from context. +func ClientSessionFromContext(ctx context.Context) ClientSession { + if session, ok := ctx.Value(clientSessionKey{}).(ClientSession); ok { + return session + } + return nil +} + +// WithContext sets the current client session and returns the provided context +func (s *MCPServer) WithContext( + ctx context.Context, + session ClientSession, +) context.Context { + return context.WithValue(ctx, clientSessionKey{}, session) +} + +// RegisterSession saves session that should be notified in case if some server attributes changed. +func (s *MCPServer) RegisterSession( + ctx context.Context, + session ClientSession, +) error { + sessionID := session.SessionID() + if _, exists := s.sessions.LoadOrStore(sessionID, session); exists { + return ErrSessionExists + } + s.hooks.RegisterSession(ctx, session) + return nil +} + +func (s *MCPServer) buildLogNotification(notification mcp.LoggingMessageNotification) mcp.JSONRPCNotification { + return mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: notification.Method, + Params: mcp.NotificationParams{ + AdditionalFields: map[string]any{ + "level": notification.Params.Level, + "logger": notification.Params.Logger, + "data": notification.Params.Data, + }, + }, + }, + } +} + +func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp.LoggingMessageNotification) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + sessionLogging, ok := session.(SessionWithLogging) + if !ok { + return ErrSessionDoesNotSupportLogging + } + if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { + return nil + } + return s.sendNotificationCore(ctx, session, s.buildLogNotification(notification)) +} + +func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) { + s.sessions.Range(func(k, v any) bool { + if session, ok := v.(ClientSession); ok && session.Initialized() { + select { + case session.NotificationChannel() <- notification: + // Successfully sent notification + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sessionID string, hooks *Hooks) { + ctx := context.Background() + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": notification.Method, + "sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID(), hooks) + } + } + } + return true + }) +} + +func (s *MCPServer) sendNotificationToSpecificClient(session ClientSession, notification mcp.JSONRPCNotification) error { + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + ctx := context.Background() + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": notification.Method, + "sessionID": sID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) + }(session.SessionID(), hooks) + } + return ErrNotificationChannelBlocked + } +} + +func (s *MCPServer) SendLogMessageToSpecificClient(sessionID string, notification mcp.LoggingMessageNotification) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + session, ok := sessionValue.(ClientSession) + if !ok || !session.Initialized() { + return ErrSessionNotInitialized + } + sessionLogging, ok := session.(SessionWithLogging) + if !ok { + return ErrSessionDoesNotSupportLogging + } + if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { + return nil + } + return s.sendNotificationToSpecificClient(session, s.buildLogNotification(notification)) +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + ctx context.Context, + sessionID string, +) { + sessionValue, ok := s.sessions.LoadAndDelete(sessionID) + if !ok { + return + } + if session, ok := sessionValue.(ClientSession); ok { + s.hooks.UnregisterSession(ctx, session) + } +} + +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( + method string, + params map[string]any, +) { + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + s.sendNotificationToAllClients(notification) +} + +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) sendNotificationCore( + ctx context.Context, + session ClientSession, + notification mcp.JSONRPCNotification, +) error { + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + method := notification.Method + err := ErrNotificationChannelBlocked + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sessionID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": method, + "sessionID": sessionID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) + }(session.SessionID(), hooks) + } + return ErrNotificationChannelBlocked + } +} + +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) SendNotificationToClient( + ctx context.Context, + method string, + params map[string]any, +) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + return s.sendNotificationCore(ctx, session, notification) +} + +// SendNotificationToSpecificClient sends a notification to a specific client by session ID +func (s *MCPServer) SendNotificationToSpecificClient( + sessionID string, + method string, + params map[string]any, +) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + session, ok := sessionValue.(ClientSession) + if !ok || !session.Initialized() { + return ErrSessionNotInitialized + } + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + return s.sendNotificationToSpecificClient(session, notification) +} + +// AddSessionTool adds a tool for a specific session +func (s *MCPServer) AddSessionTool(sessionID string, tool mcp.Tool, handler ToolHandlerFunc) error { + return s.AddSessionTools(sessionID, ServerTool{Tool: tool, Handler: handler}) +} + +// AddSessionTools adds tools for a specific session +func (s *MCPServer) AddSessionTools(sessionID string, tools ...ServerTool) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + s.implicitlyRegisterToolCapabilities() + + // Get existing tools (this should return a thread-safe copy) + sessionTools := session.GetSessionTools() + + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)+len(tools)) + + // Copy existing tools + for k, v := range sessionTools { + newSessionTools[k] = v + } + + // Add new tools + for _, tool := range tools { + newSessionTools[tool.Tool.Name] = tool + } + + // Set the tools (this should be thread-safe) + session.SetSessionTools(newSessionTools) + + // It only makes sense to send tool notifications to initialized sessions -- + // if we're not initialized yet the client can't possibly have sent their + // initial tools/list message. + // + // For initialized sessions, honor tools.listChanged, which is specifically + // about whether notifications will be sent or not. + // see + if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully added, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after adding tools: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} + +// DeleteSessionTools removes tools from a specific session +func (s *MCPServer) DeleteSessionTools(sessionID string, names ...string) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + + session, ok := sessionValue.(SessionWithTools) + if !ok { + return ErrSessionDoesNotSupportTools + } + + // Get existing tools (this should return a thread-safe copy) + sessionTools := session.GetSessionTools() + if sessionTools == nil { + return nil + } + + // Create a new map to avoid concurrent modification issues + newSessionTools := make(map[string]ServerTool, len(sessionTools)) + + // Copy existing tools except those being deleted + for k, v := range sessionTools { + newSessionTools[k] = v + } + + // Remove specified tools + for _, name := range names { + delete(newSessionTools, name) + } + + // Set the tools (this should be thread-safe) + session.SetSessionTools(newSessionTools) + + // It only makes sense to send tool notifications to initialized sessions -- + // if we're not initialized yet the client can't possibly have sent their + // initial tools/list message. + // + // For initialized sessions, honor tools.listChanged, which is specifically + // about whether notifications will be sent or not. + // see + if session.Initialized() && s.capabilities.tools != nil && s.capabilities.tools.listChanged { + // Send notification only to this session + if err := s.SendNotificationToSpecificClient(sessionID, "notifications/tools/list_changed", nil); err != nil { + // Log the error but don't fail the operation + // The tools were successfully deleted, but notification failed + if s.hooks != nil && len(s.hooks.OnError) > 0 { + hooks := s.hooks + go func(sID string, hooks *Hooks) { + ctx := context.Background() + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": "notifications/tools/list_changed", + "sessionID": sID, + }, fmt.Errorf("failed to send notification after deleting tools: %w", err)) + }(sessionID, hooks) + } + } + } + + return nil +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/sse.go b/vendor/github.com/mark3labs/mcp-go/server/sse.go new file mode 100644 index 0000000000..9c9766cf3e --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/sse.go @@ -0,0 +1,751 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "log" + "net/http" + "net/http/httptest" + "net/url" + "path" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + + "github.com/mark3labs/mcp-go/mcp" +) + +// sseSession represents an active SSE connection. +type sseSession struct { + done chan struct{} + eventQueue chan string // Channel for queuing events + sessionID string + requestID atomic.Int64 + notificationChannel chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + tools sync.Map // stores session-specific tools + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities +} + +// SSEContextFunc is a function that takes an existing context and the current +// request and returns a potentially modified context based on the request +// content. This can be used to inject context values from headers, for example. +type SSEContextFunc func(ctx context.Context, r *http.Request) context.Context + +// DynamicBasePathFunc allows the user to provide a function to generate the +// base path for a given request and sessionID. This is useful for cases where +// the base path is not known at the time of SSE server creation, such as when +// using a reverse proxy or when the base path is dynamically generated. The +// function should return the base path (e.g., "/mcp/tenant123"). +type DynamicBasePathFunc func(r *http.Request, sessionID string) string + +func (s *sseSession) SessionID() string { + return s.sessionID +} + +func (s *sseSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +func (s *sseSession) Initialize() { + // set default logging level + s.loggingLevel.Store(mcp.LoggingLevelError) + s.initialized.Store(true) +} + +func (s *sseSession) Initialized() bool { + return s.initialized.Load() +} + +func (s *sseSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *sseSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +func (s *sseSession) GetSessionTools() map[string]ServerTool { + tools := make(map[string]ServerTool) + s.tools.Range(func(key, value any) bool { + if tool, ok := value.(ServerTool); ok { + tools[key.(string)] = tool + } + return true + }) + return tools +} + +func (s *sseSession) SetSessionTools(tools map[string]ServerTool) { + // Clear existing tools + s.tools.Clear() + + // Set new tools + for name, tool := range tools { + s.tools.Store(name, tool) + } +} + +func (s *sseSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *sseSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +var ( + _ ClientSession = (*sseSession)(nil) + _ SessionWithTools = (*sseSession)(nil) + _ SessionWithLogging = (*sseSession)(nil) + _ SessionWithClientInfo = (*sseSession)(nil) +) + +// SSEServer implements a Server-Sent Events (SSE) based MCP server. +// It provides real-time communication capabilities over HTTP using the SSE protocol. +type SSEServer struct { + server *MCPServer + baseURL string + basePath string + appendQueryToMessageEndpoint bool + useFullURLForMessageEndpoint bool + messageEndpoint string + sseEndpoint string + sessions sync.Map + srv *http.Server + contextFunc SSEContextFunc + dynamicBasePathFunc DynamicBasePathFunc + + keepAlive bool + keepAliveInterval time.Duration + + mu sync.RWMutex +} + +// SSEOption defines a function type for configuring SSEServer +type SSEOption func(*SSEServer) + +// WithBaseURL sets the base URL for the SSE server +func WithBaseURL(baseURL string) SSEOption { + return func(s *SSEServer) { + if baseURL != "" { + u, err := url.Parse(baseURL) + if err != nil { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + // Check if the host is empty or only contains a port + if u.Host == "" || strings.HasPrefix(u.Host, ":") { + return + } + if len(u.Query()) > 0 { + return + } + } + s.baseURL = strings.TrimSuffix(baseURL, "/") + } +} + +// WithStaticBasePath adds a new option for setting a static base path +func WithStaticBasePath(basePath string) SSEOption { + return func(s *SSEServer) { + s.basePath = normalizeURLPath(basePath) + } +} + +// WithBasePath adds a new option for setting a static base path. +// +// Deprecated: Use WithStaticBasePath instead. This will be removed in a future version. +// +//go:deprecated +func WithBasePath(basePath string) SSEOption { + return WithStaticBasePath(basePath) +} + +// WithDynamicBasePath accepts a function for generating the base path. This is +// useful for cases where the base path is not known at the time of SSE server +// creation, such as when using a reverse proxy or when the server is mounted +// at a dynamic path. +func WithDynamicBasePath(fn DynamicBasePathFunc) SSEOption { + return func(s *SSEServer) { + if fn != nil { + s.dynamicBasePathFunc = func(r *http.Request, sid string) string { + bp := fn(r, sid) + return normalizeURLPath(bp) + } + } + } +} + +// WithMessageEndpoint sets the message endpoint path +func WithMessageEndpoint(endpoint string) SSEOption { + return func(s *SSEServer) { + s.messageEndpoint = endpoint + } +} + +// WithAppendQueryToMessageEndpoint configures the SSE server to append the original request's +// query parameters to the message endpoint URL that is sent to clients during the SSE connection +// initialization. This is useful when you need to preserve query parameters from the initial +// SSE connection request and carry them over to subsequent message requests, maintaining +// context or authentication details across the communication channel. +func WithAppendQueryToMessageEndpoint() SSEOption { + return func(s *SSEServer) { + s.appendQueryToMessageEndpoint = true + } +} + +// WithUseFullURLForMessageEndpoint controls whether the SSE server returns a complete URL (including baseURL) +// or just the path portion for the message endpoint. Set to false when clients will concatenate +// the baseURL themselves to avoid malformed URLs like "http://localhost/mcphttp://localhost/mcp/message". +func WithUseFullURLForMessageEndpoint(useFullURLForMessageEndpoint bool) SSEOption { + return func(s *SSEServer) { + s.useFullURLForMessageEndpoint = useFullURLForMessageEndpoint + } +} + +// WithSSEEndpoint sets the SSE endpoint path +func WithSSEEndpoint(endpoint string) SSEOption { + return func(s *SSEServer) { + s.sseEndpoint = endpoint + } +} + +// WithHTTPServer sets the HTTP server instance. +// NOTE: When providing a custom HTTP server, you must handle routing yourself +// If routing is not set up, the server will start but won't handle any MCP requests. +func WithHTTPServer(srv *http.Server) SSEOption { + return func(s *SSEServer) { + s.srv = srv + } +} + +func WithKeepAliveInterval(keepAliveInterval time.Duration) SSEOption { + return func(s *SSEServer) { + s.keepAlive = true + s.keepAliveInterval = keepAliveInterval + } +} + +func WithKeepAlive(keepAlive bool) SSEOption { + return func(s *SSEServer) { + s.keepAlive = keepAlive + } +} + +// WithSSEContextFunc sets a function that will be called to customise the context +// to the server using the incoming request. +func WithSSEContextFunc(fn SSEContextFunc) SSEOption { + return func(s *SSEServer) { + s.contextFunc = fn + } +} + +// NewSSEServer creates a new SSE server instance with the given MCP server and options. +func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer { + s := &SSEServer{ + server: server, + sseEndpoint: "/sse", + messageEndpoint: "/message", + useFullURLForMessageEndpoint: true, + keepAlive: false, + keepAliveInterval: 10 * time.Second, + } + + // Apply all options + for _, opt := range opts { + opt(s) + } + + return s +} + +// NewTestServer creates a test server for testing purposes +func NewTestServer(server *MCPServer, opts ...SSEOption) *httptest.Server { + sseServer := NewSSEServer(server, opts...) + + testServer := httptest.NewServer(sseServer) + sseServer.baseURL = testServer.URL + return testServer +} + +// Start begins serving SSE connections on the specified address. +// It sets up HTTP handlers for SSE and message endpoints. +func (s *SSEServer) Start(addr string) error { + s.mu.Lock() + if s.srv == nil { + s.srv = &http.Server{ + Addr: addr, + Handler: s, + } + } else { + if s.srv.Addr == "" { + s.srv.Addr = addr + } else if s.srv.Addr != addr { + return fmt.Errorf("conflicting listen address: WithHTTPServer(%q) vs Start(%q)", s.srv.Addr, addr) + } + } + srv := s.srv + s.mu.Unlock() + + return srv.ListenAndServe() +} + +// Shutdown gracefully stops the SSE server, closing all active sessions +// and shutting down the HTTP server. +func (s *SSEServer) Shutdown(ctx context.Context) error { + s.mu.RLock() + srv := s.srv + s.mu.RUnlock() + + if srv != nil { + s.sessions.Range(func(key, value any) bool { + if session, ok := value.(*sseSession); ok { + close(session.done) + } + s.sessions.Delete(key) + return true + }) + + return srv.Shutdown(ctx) + } + return nil +} + +// handleSSE handles incoming SSE connection requests. +// It sets up appropriate headers and creates a new session for the client. +func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Access-Control-Allow-Origin", "*") + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + sessionID := uuid.New().String() + session := &sseSession{ + done: make(chan struct{}), + eventQueue: make(chan string, 100), // Buffer for events + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + } + + s.sessions.Store(sessionID, session) + defer s.sessions.Delete(sessionID) + + if err := s.server.RegisterSession(r.Context(), session); err != nil { + http.Error( + w, + fmt.Sprintf("Session registration failed: %v", err), + http.StatusInternalServerError, + ) + return + } + defer s.server.UnregisterSession(r.Context(), sessionID) + + // Start notification handler for this session + go func() { + for { + select { + case notification := <-session.notificationChannel: + eventData, err := json.Marshal(notification) + if err == nil { + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + // Event queued successfully + case <-session.done: + return + } + } + case <-session.done: + return + case <-r.Context().Done(): + return + } + } + }() + + // Start keep alive : ping + if s.keepAlive { + go func() { + ticker := time.NewTicker(s.keepAliveInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + message := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(session.requestID.Add(1)), + Request: mcp.Request{ + Method: "ping", + }, + } + messageBytes, _ := json.Marshal(message) + pingMsg := fmt.Sprintf("event: message\ndata:%s\n\n", messageBytes) + select { + case session.eventQueue <- pingMsg: + // Message sent successfully + case <-session.done: + return + } + case <-session.done: + return + case <-r.Context().Done(): + return + } + } + }() + } + + // Send the initial endpoint event + endpoint := s.GetMessageEndpointForClient(r, sessionID) + if s.appendQueryToMessageEndpoint && len(r.URL.RawQuery) > 0 { + endpoint += "&" + r.URL.RawQuery + } + fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", endpoint) + flusher.Flush() + + // Main event loop - this runs in the HTTP handler goroutine + for { + select { + case event := <-session.eventQueue: + // Write the event to the response + fmt.Fprint(w, event) + flusher.Flush() + case <-r.Context().Done(): + close(session.done) + return + case <-session.done: + return + } + } +} + +// GetMessageEndpointForClient returns the appropriate message endpoint URL with session ID +// for the given request. This is the canonical way to compute the message endpoint for a client. +// It handles both dynamic and static path modes, and honors the WithUseFullURLForMessageEndpoint flag. +func (s *SSEServer) GetMessageEndpointForClient(r *http.Request, sessionID string) string { + basePath := s.basePath + if s.dynamicBasePathFunc != nil { + basePath = s.dynamicBasePathFunc(r, sessionID) + } + + endpointPath := normalizeURLPath(basePath, s.messageEndpoint) + if s.useFullURLForMessageEndpoint && s.baseURL != "" { + endpointPath = s.baseURL + endpointPath + } + + return fmt.Sprintf("%s?sessionId=%s", endpointPath, sessionID) +} + +// handleMessage processes incoming JSON-RPC messages from clients and sends responses +// back through the SSE connection and 202 code to HTTP response. +func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed") + return + } + + sessionID := r.URL.Query().Get("sessionId") + if sessionID == "" { + s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") + return + } + sessionI, ok := s.sessions.Load(sessionID) + if !ok { + s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Invalid session ID") + return + } + session := sessionI.(*sseSession) + + // Set the client context before handling the message + ctx := s.server.WithContext(r.Context(), session) + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // Parse message as raw JSON + var rawMessage json.RawMessage + if err := json.NewDecoder(r.Body).Decode(&rawMessage); err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "Parse error") + return + } + + // Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled. + // this is required because the http ctx will be canceled when the client disconnects + detachedCtx := context.WithoutCancel(ctx) + + // quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE + w.WriteHeader(http.StatusAccepted) + + // Create a new context for handling the message that will be canceled when the message handling is done + messageCtx := context.WithValue(detachedCtx, requestHeader, r.Header) + messageCtx, cancel := context.WithCancel(messageCtx) + + go func(ctx context.Context) { + defer cancel() + // Use the context that will be canceled when session is done + // Process message through MCPServer + response := s.server.HandleMessage(ctx, rawMessage) + // Only send response if there is one (not for notifications) + if response != nil { + var message string + if eventData, err := json.Marshal(response); err != nil { + // If there is an error marshalling the response, send a generic error response + log.Printf("failed to marshal response: %v", err) + message = "event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n" + } else { + message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData) + } + + // Queue the event for sending via SSE + select { + case session.eventQueue <- message: + // Event queued successfully + case <-session.done: + // Session is closed, don't try to queue + default: + // Queue is full, log this situation + log.Printf("Event queue full for session %s", sessionID) + } + } + }(messageCtx) +} + +// writeJSONRPCError writes a JSON-RPC error response with the given error details. +func (s *SSEServer) writeJSONRPCError( + w http.ResponseWriter, + id any, + code int, + message string, +) { + response := createErrorResponse(id, code, message) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + if err := json.NewEncoder(w).Encode(response); err != nil { + http.Error( + w, + fmt.Sprintf("Failed to encode response: %v", err), + http.StatusInternalServerError, + ) + return + } +} + +// SendEventToSession sends an event to a specific SSE session identified by sessionID. +// Returns an error if the session is not found or closed. +func (s *SSEServer) SendEventToSession( + sessionID string, + event any, +) error { + sessionI, ok := s.sessions.Load(sessionID) + if !ok { + return fmt.Errorf("session not found: %s", sessionID) + } + session := sessionI.(*sseSession) + + eventData, err := json.Marshal(event) + if err != nil { + return err + } + + // Queue the event for sending via SSE + select { + case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData): + return nil + case <-session.done: + return fmt.Errorf("session closed") + default: + return fmt.Errorf("event queue full") + } +} + +func (s *SSEServer) GetUrlPath(input string) (string, error) { + parse, err := url.Parse(input) + if err != nil { + return "", fmt.Errorf("failed to parse URL %s: %w", input, err) + } + return parse.Path, nil +} + +func (s *SSEServer) CompleteSseEndpoint() (string, error) { + if s.dynamicBasePathFunc != nil { + return "", &ErrDynamicPathConfig{Method: "CompleteSseEndpoint"} + } + + path := normalizeURLPath(s.basePath, s.sseEndpoint) + return s.baseURL + path, nil +} + +func (s *SSEServer) CompleteSsePath() string { + path, err := s.CompleteSseEndpoint() + if err != nil { + return normalizeURLPath(s.basePath, s.sseEndpoint) + } + urlPath, err := s.GetUrlPath(path) + if err != nil { + return normalizeURLPath(s.basePath, s.sseEndpoint) + } + return urlPath +} + +func (s *SSEServer) CompleteMessageEndpoint() (string, error) { + if s.dynamicBasePathFunc != nil { + return "", &ErrDynamicPathConfig{Method: "CompleteMessageEndpoint"} + } + path := normalizeURLPath(s.basePath, s.messageEndpoint) + return s.baseURL + path, nil +} + +func (s *SSEServer) CompleteMessagePath() string { + path, err := s.CompleteMessageEndpoint() + if err != nil { + return normalizeURLPath(s.basePath, s.messageEndpoint) + } + urlPath, err := s.GetUrlPath(path) + if err != nil { + return normalizeURLPath(s.basePath, s.messageEndpoint) + } + return urlPath +} + +// SSEHandler returns an http.Handler for the SSE endpoint. +// +// This method allows you to mount the SSE handler at any arbitrary path +// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is +// intended for advanced scenarios where you want to control the routing or +// support dynamic segments. +// +// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios, +// you must use the WithDynamicBasePath option to ensure the correct base path +// is communicated to clients. +// +// Example usage: +// +// // Advanced/dynamic: +// sseServer := NewSSEServer(mcpServer, +// WithDynamicBasePath(func(r *http.Request, sessionID string) string { +// tenant := r.PathValue("tenant") +// return "/mcp/" + tenant +// }), +// WithBaseURL("http://localhost:8080") +// ) +// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler()) +// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler()) +// +// For non-dynamic cases, use ServeHTTP method instead. +func (s *SSEServer) SSEHandler() http.Handler { + return http.HandlerFunc(s.handleSSE) +} + +// MessageHandler returns an http.Handler for the message endpoint. +// +// This method allows you to mount the message handler at any arbitrary path +// using your own router (e.g. net/http, gorilla/mux, chi, etc.). It is +// intended for advanced scenarios where you want to control the routing or +// support dynamic segments. +// +// IMPORTANT: When using this handler in advanced/dynamic mounting scenarios, +// you must use the WithDynamicBasePath option to ensure the correct base path +// is communicated to clients. +// +// Example usage: +// +// // Advanced/dynamic: +// sseServer := NewSSEServer(mcpServer, +// WithDynamicBasePath(func(r *http.Request, sessionID string) string { +// tenant := r.PathValue("tenant") +// return "/mcp/" + tenant +// }), +// WithBaseURL("http://localhost:8080") +// ) +// mux.Handle("/mcp/{tenant}/sse", sseServer.SSEHandler()) +// mux.Handle("/mcp/{tenant}/message", sseServer.MessageHandler()) +// +// For non-dynamic cases, use ServeHTTP method instead. +func (s *SSEServer) MessageHandler() http.Handler { + return http.HandlerFunc(s.handleMessage) +} + +// ServeHTTP implements the http.Handler interface. +func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if s.dynamicBasePathFunc != nil { + http.Error( + w, + (&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(), + http.StatusInternalServerError, + ) + return + } + path := r.URL.Path + // Use exact path matching rather than Contains + ssePath := s.CompleteSsePath() + if ssePath != "" && path == ssePath { + s.handleSSE(w, r) + return + } + messagePath := s.CompleteMessagePath() + if messagePath != "" && path == messagePath { + s.handleMessage(w, r) + return + } + + http.NotFound(w, r) +} + +// normalizeURLPath joins path elements like path.Join but ensures the +// result always starts with a leading slash and never ends with a slash +func normalizeURLPath(elem ...string) string { + joined := path.Join(elem...) + + // Ensure leading slash + if !strings.HasPrefix(joined, "/") { + joined = "/" + joined + } + + // Remove trailing slash if not just "/" + if len(joined) > 1 && strings.HasSuffix(joined, "/") { + joined = joined[:len(joined)-1] + } + + return joined +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/stdio.go b/vendor/github.com/mark3labs/mcp-go/server/stdio.go new file mode 100644 index 0000000000..8c270e18b7 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/stdio.go @@ -0,0 +1,592 @@ +package server + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "log" + "os" + "os/signal" + "sync" + "sync/atomic" + "syscall" + + "github.com/mark3labs/mcp-go/mcp" +) + +// StdioContextFunc is a function that takes an existing context and returns +// a potentially modified context. +// This can be used to inject context values from environment variables, +// for example. +type StdioContextFunc func(ctx context.Context) context.Context + +// StdioServer wraps a MCPServer and handles stdio communication. +// It provides a simple way to create command-line MCP servers that +// communicate via standard input/output streams using JSON-RPC messages. +type StdioServer struct { + server *MCPServer + errLogger *log.Logger + contextFunc StdioContextFunc + + // Thread-safe tool call processing + toolCallQueue chan *toolCallWork + workerWg sync.WaitGroup + workerPoolSize int + queueSize int + writeMu sync.Mutex // Protects concurrent writes +} + +// toolCallWork represents a queued tool call request +type toolCallWork struct { + ctx context.Context + message json.RawMessage + writer io.Writer +} + +// StdioOption defines a function type for configuring StdioServer +type StdioOption func(*StdioServer) + +// WithErrorLogger sets the error logger for the server +func WithErrorLogger(logger *log.Logger) StdioOption { + return func(s *StdioServer) { + s.errLogger = logger + } +} + +// WithStdioContextFunc sets a function that will be called to customise the context +// to the server. Note that the stdio server uses the same context for all requests, +// so this function will only be called once per server instance. +func WithStdioContextFunc(fn StdioContextFunc) StdioOption { + return func(s *StdioServer) { + s.contextFunc = fn + } +} + +// WithWorkerPoolSize sets the number of workers for processing tool calls +func WithWorkerPoolSize(size int) StdioOption { + return func(s *StdioServer) { + const maxWorkerPoolSize = 100 + if size > 0 && size <= maxWorkerPoolSize { + s.workerPoolSize = size + } else if size > maxWorkerPoolSize { + s.errLogger.Printf("Worker pool size %d exceeds maximum (%d), using maximum", size, maxWorkerPoolSize) + s.workerPoolSize = maxWorkerPoolSize + } + } +} + +// WithQueueSize sets the size of the tool call queue +func WithQueueSize(size int) StdioOption { + return func(s *StdioServer) { + const maxQueueSize = 10000 + if size > 0 && size <= maxQueueSize { + s.queueSize = size + } else if size > maxQueueSize { + s.errLogger.Printf("Queue size %d exceeds maximum (%d), using maximum", size, maxQueueSize) + s.queueSize = maxQueueSize + } + } +} + +// stdioSession is a static client session, since stdio has only one client. +type stdioSession struct { + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingMu sync.RWMutex // protects pendingRequests +} + +// samplingResponse represents a response to a sampling request +type samplingResponse struct { + result *mcp.CreateMessageResult + err error +} + +func (s *stdioSession) SessionID() string { + return "stdio" +} + +func (s *stdioSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notifications +} + +func (s *stdioSession) Initialize() { + // set default logging level + s.loggingLevel.Store(mcp.LoggingLevelError) + s.initialized.Store(true) +} + +func (s *stdioSession) Initialized() bool { + return s.initialized.Load() +} + +func (s *stdioSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *stdioSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +// RequestSampling sends a sampling request to the client and waits for the response. +func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *samplingResponse, 1) + s.pendingMu.Lock() + s.pendingRequests[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingRequests, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params mcp.CreateMessageParams `json:"params"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodSamplingCreateMessage), + Params: request.CreateMessageParams, + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal sampling request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write sampling request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + +// SetWriter sets the writer for sending requests to the client. +func (s *stdioSession) SetWriter(writer io.Writer) { + s.mu.Lock() + defer s.mu.Unlock() + s.writer = writer +} + +var ( + _ ClientSession = (*stdioSession)(nil) + _ SessionWithLogging = (*stdioSession)(nil) + _ SessionWithClientInfo = (*stdioSession)(nil) + _ SessionWithSampling = (*stdioSession)(nil) +) + +var stdioSessionInstance = stdioSession{ + notifications: make(chan mcp.JSONRPCNotification, 100), + pendingRequests: make(map[int64]chan *samplingResponse), +} + +// NewStdioServer creates a new stdio server wrapper around an MCPServer. +// It initializes the server with a default error logger that discards all output. +func NewStdioServer(server *MCPServer) *StdioServer { + return &StdioServer{ + server: server, + errLogger: log.New( + os.Stderr, + "", + log.LstdFlags, + ), // Default to discarding logs + workerPoolSize: 5, // Default worker pool size + queueSize: 100, // Default queue size + } +} + +// SetErrorLogger configures where error messages from the StdioServer are logged. +// The provided logger will receive all error messages generated during server operation. +func (s *StdioServer) SetErrorLogger(logger *log.Logger) { + s.errLogger = logger +} + +// SetContextFunc sets a function that will be called to customise the context +// to the server. Note that the stdio server uses the same context for all requests, +// so this function will only be called once per server instance. +func (s *StdioServer) SetContextFunc(fn StdioContextFunc) { + s.contextFunc = fn +} + +// handleNotifications continuously processes notifications from the session's notification channel +// and writes them to the provided output. It runs until the context is cancelled. +// Any errors encountered while writing notifications are logged but do not stop the handler. +func (s *StdioServer) handleNotifications(ctx context.Context, stdout io.Writer) { + for { + select { + case notification := <-stdioSessionInstance.notifications: + if err := s.writeResponse(notification, stdout); err != nil { + s.errLogger.Printf("Error writing notification: %v", err) + } + case <-ctx.Done(): + return + } + } +} + +// processInputStream continuously reads and processes messages from the input stream. +// It handles EOF gracefully as a normal termination condition. +// The function returns when either: +// - The context is cancelled (returns context.Err()) +// - EOF is encountered (returns nil) +// - An error occurs while reading or processing messages (returns the error) +func (s *StdioServer) processInputStream(ctx context.Context, reader *bufio.Reader, stdout io.Writer) error { + for { + if err := ctx.Err(); err != nil { + return err + } + + line, err := s.readNextLine(ctx, reader) + if err != nil { + if err == io.EOF { + return nil + } + s.errLogger.Printf("Error reading input: %v", err) + return err + } + + if err := s.processMessage(ctx, line, stdout); err != nil { + if err == io.EOF { + return nil + } + s.errLogger.Printf("Error handling message: %v", err) + return err + } + } +} + +// toolCallWorker processes tool calls from the queue +func (s *StdioServer) toolCallWorker(ctx context.Context) { + defer s.workerWg.Done() + + for { + select { + case work, ok := <-s.toolCallQueue: + if !ok { + // Channel closed, exit worker + return + } + // Process the tool call + response := s.server.HandleMessage(work.ctx, work.message) + if response != nil { + if err := s.writeResponse(response, work.writer); err != nil { + s.errLogger.Printf("Error writing tool response: %v", err) + } + } + case <-ctx.Done(): + return + } + } +} + +// readNextLine reads a single line from the input reader in a context-aware manner. +// It uses channels to make the read operation cancellable via context. +// Returns the read line and any error encountered. If the context is cancelled, +// returns an empty string and the context's error. EOF is returned when the input +// stream is closed. +func (s *StdioServer) readNextLine(ctx context.Context, reader *bufio.Reader) (string, error) { + type result struct { + line string + err error + } + + resultCh := make(chan result, 1) + + go func() { + line, err := reader.ReadString('\n') + resultCh <- result{line: line, err: err} + }() + + select { + case <-ctx.Done(): + return "", nil + case res := <-resultCh: + return res.line, res.err + } +} + +// Listen starts listening for JSON-RPC messages on the provided input and writes responses to the provided output. +// It runs until the context is cancelled or an error occurs. +// Returns an error if there are issues with reading input or writing output. +func (s *StdioServer) Listen( + ctx context.Context, + stdin io.Reader, + stdout io.Writer, +) error { + // Initialize the tool call queue + s.toolCallQueue = make(chan *toolCallWork, s.queueSize) + + // Set a static client context since stdio only has one client + if err := s.server.RegisterSession(ctx, &stdioSessionInstance); err != nil { + return fmt.Errorf("register session: %w", err) + } + defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) + ctx = s.server.WithContext(ctx, &stdioSessionInstance) + + // Set the writer for sending requests to the client + stdioSessionInstance.SetWriter(stdout) + + // Add in any custom context. + if s.contextFunc != nil { + ctx = s.contextFunc(ctx) + } + + reader := bufio.NewReader(stdin) + + // Start worker pool for tool calls + for i := 0; i < s.workerPoolSize; i++ { + s.workerWg.Add(1) + go s.toolCallWorker(ctx) + } + + // Start notification handler + go s.handleNotifications(ctx, stdout) + + // Process input stream + err := s.processInputStream(ctx, reader, stdout) + + // Shutdown workers gracefully + close(s.toolCallQueue) + s.workerWg.Wait() + + return err +} + +// processMessage handles a single JSON-RPC message and writes the response. +// It parses the message, processes it through the wrapped MCPServer, and writes any response. +// Returns an error if there are issues with message processing or response writing. +func (s *StdioServer) processMessage( + ctx context.Context, + line string, + writer io.Writer, +) error { + // If line is empty, likely due to ctx cancellation + if len(line) == 0 { + return nil + } + + // Parse the message as raw JSON + var rawMessage json.RawMessage + if err := json.Unmarshal([]byte(line), &rawMessage); err != nil { + response := createErrorResponse(nil, mcp.PARSE_ERROR, "Parse error") + return s.writeResponse(response, writer) + } + + // Check if this is a response to a sampling request + if s.handleSamplingResponse(rawMessage) { + return nil + } + + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) + var baseMessage struct { + Method string `json:"method"` + } + if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" { + // Queue tool calls for processing by workers + select { + case s.toolCallQueue <- &toolCallWork{ + ctx: ctx, + message: rawMessage, + writer: writer, + }: + return nil + case <-ctx.Done(): + return ctx.Err() + default: + // Queue is full, process synchronously as fallback + s.errLogger.Printf("Tool call queue full, processing synchronously") + response := s.server.HandleMessage(ctx, rawMessage) + if response != nil { + return s.writeResponse(response, writer) + } + return nil + } + } + + // Handle other messages synchronously + response := s.server.HandleMessage(ctx, rawMessage) + + // Only write response if there is one (not for notifications) + if response != nil { + if err := s.writeResponse(response, writer); err != nil { + return fmt.Errorf("failed to write response: %w", err) + } + } + + return nil +} + +// handleSamplingResponse checks if the message is a response to a sampling request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleSamplingResponse(rawMessage) +} + +// handleSamplingResponse handles incoming sampling responses for this session +func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + idInt64, err := response.ID.Int64() + if err != nil || (response.Result == nil && response.Error == nil) { + return false + } + + // Look for a pending request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingRequests[idInt64] + s.pendingMu.RUnlock() + + if !exists { + return false + } // Parse and send the response + samplingResp := &samplingResponse{} + + if response.Error != nil { + samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message) + } else { + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.Result, &result); err != nil { + samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) + } else { + samplingResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- samplingResp: + default: + // Channel is full or closed, ignore + } + + return true +} + +// writeResponse marshals and writes a JSON-RPC response message followed by a newline. +// Returns an error if marshaling or writing fails. +func (s *StdioServer) writeResponse( + response mcp.JSONRPCMessage, + writer io.Writer, +) error { + responseBytes, err := json.Marshal(response) + if err != nil { + return err + } + + // Protect concurrent writes + s.writeMu.Lock() + defer s.writeMu.Unlock() + + // Write response followed by newline + if _, err := fmt.Fprintf(writer, "%s\n", responseBytes); err != nil { + return err + } + + return nil +} + +// ServeStdio is a convenience function that creates and starts a StdioServer with os.Stdin and os.Stdout. +// It sets up signal handling for graceful shutdown on SIGTERM and SIGINT. +// Returns an error if the server encounters any issues during operation. +func ServeStdio(server *MCPServer, opts ...StdioOption) error { + s := NewStdioServer(server) + + for _, opt := range opts { + opt(s) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Set up signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) + + go func() { + <-sigChan + cancel() + }() + + return s.Listen(ctx, os.Stdin, os.Stdout) +} diff --git a/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go new file mode 100644 index 0000000000..c97d9b7471 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/server/streamable_http.go @@ -0,0 +1,939 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "io" + "mime" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" +) + +// StreamableHTTPOption defines a function type for configuring StreamableHTTPServer +type StreamableHTTPOption func(*StreamableHTTPServer) + +// WithEndpointPath sets the endpoint path for the server. +// The default is "/mcp". +// It's only works for `Start` method. When used as a http.Handler, it has no effect. +func WithEndpointPath(endpointPath string) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + // Normalize the endpoint path to ensure it starts with a slash and doesn't end with one + normalizedPath := "/" + strings.Trim(endpointPath, "/") + s.endpointPath = normalizedPath + } +} + +// WithStateLess sets the server to stateless mode. +// If true, the server will manage no session information. Every request will be treated +// as a new session. No session id returned to the client. +// The default is false. +// +// Notice: This is a convenience method. It's identical to set WithSessionIdManager option +// to StatelessSessionIdManager. +func WithStateLess(stateLess bool) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + if stateLess { + s.sessionIdManager = &StatelessSessionIdManager{} + } + } +} + +// WithSessionIdManager sets a custom session id generator for the server. +// By default, the server will use SimpleStatefulSessionIdGenerator, which generates +// session ids with uuid, and it's insecure. +// Notice: it will override the WithStateLess option. +func WithSessionIdManager(manager SessionIdManager) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.sessionIdManager = manager + } +} + +// WithHeartbeatInterval sets the heartbeat interval. Positive interval means the +// server will send a heartbeat to the client through the GET connection, to keep +// the connection alive from being closed by the network infrastructure (e.g. +// gateways). If the client does not establish a GET connection, it has no +// effect. The default is not to send heartbeats. +func WithHeartbeatInterval(interval time.Duration) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.listenHeartbeatInterval = interval + } +} + +// WithHTTPContextFunc sets a function that will be called to customise the context +// to the server using the incoming request. +// This can be used to inject context values from headers, for example. +func WithHTTPContextFunc(fn HTTPContextFunc) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.contextFunc = fn + } +} + +// WithStreamableHTTPServer sets the HTTP server instance for StreamableHTTPServer. +// NOTE: When providing a custom HTTP server, you must handle routing yourself +// If routing is not set up, the server will start but won't handle any MCP requests. +func WithStreamableHTTPServer(srv *http.Server) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.httpServer = srv + } +} + +// WithLogger sets the logger for the server +func WithLogger(logger util.Logger) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.logger = logger + } +} + +// WithTLSCert sets the TLS certificate and key files for HTTPS support. +// Both certFile and keyFile must be provided to enable TLS. +func WithTLSCert(certFile, keyFile string) StreamableHTTPOption { + return func(s *StreamableHTTPServer) { + s.tlsCertFile = certFile + s.tlsKeyFile = keyFile + } +} + +// StreamableHTTPServer implements a Streamable-http based MCP server. +// It communicates with clients over HTTP protocol, supporting both direct HTTP responses, and SSE streams. +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http +// +// Usage: +// +// server := NewStreamableHTTPServer(mcpServer) +// server.Start(":8080") // The final url for client is http://xxxx:8080/mcp by default +// +// or the server itself can be used as a http.Handler, which is convenient to +// integrate with existing http servers, or advanced usage: +// +// handler := NewStreamableHTTPServer(mcpServer) +// http.Handle("/streamable-http", handler) +// http.ListenAndServe(":8080", nil) +// +// Notice: +// Except for the GET handlers(listening), the POST handlers(request/notification) will +// not trigger the session registration. So the methods like `SendNotificationToSpecificClient` +// or `hooks.onRegisterSession` will not be triggered for POST messages. +// +// The current implementation does not support the following features from the specification: +// - Stream Resumability +type StreamableHTTPServer struct { + server *MCPServer + sessionTools *sessionToolsStore + sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64) + activeSessions sync.Map // sessionId --> *streamableHttpSession (for sampling responses) + + httpServer *http.Server + mu sync.RWMutex + + endpointPath string + contextFunc HTTPContextFunc + sessionIdManager SessionIdManager + listenHeartbeatInterval time.Duration + logger util.Logger + sessionLogLevels *sessionLogLevelsStore + + tlsCertFile string + tlsKeyFile string +} + +// NewStreamableHTTPServer creates a new streamable-http server instance +func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *StreamableHTTPServer { + s := &StreamableHTTPServer{ + server: server, + sessionTools: newSessionToolsStore(), + sessionLogLevels: newSessionLogLevelsStore(), + endpointPath: "/mcp", + sessionIdManager: &InsecureStatefulSessionIdManager{}, + logger: util.DefaultLogger(), + } + + // Apply all options + for _, opt := range opts { + opt(s) + } + return s +} + +// ServeHTTP implements the http.Handler interface. +func (s *StreamableHTTPServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodPost: + s.handlePost(w, r) + case http.MethodGet: + s.handleGet(w, r) + case http.MethodDelete: + s.handleDelete(w, r) + default: + http.NotFound(w, r) + } +} + +// Start begins serving the http server on the specified address and path +// (endpointPath). like: +// +// s.Start(":8080") +func (s *StreamableHTTPServer) Start(addr string) error { + s.mu.Lock() + if s.httpServer == nil { + mux := http.NewServeMux() + mux.Handle(s.endpointPath, s) + s.httpServer = &http.Server{ + Addr: addr, + Handler: mux, + } + } else { + if s.httpServer.Addr == "" { + s.httpServer.Addr = addr + } else if s.httpServer.Addr != addr { + return fmt.Errorf("conflicting listen address: WithStreamableHTTPServer(%q) vs Start(%q)", s.httpServer.Addr, addr) + } + } + srv := s.httpServer + s.mu.Unlock() + + if s.tlsCertFile != "" || s.tlsKeyFile != "" { + if s.tlsCertFile == "" || s.tlsKeyFile == "" { + return fmt.Errorf("both TLS cert and key must be provided") + } + if _, err := os.Stat(s.tlsCertFile); err != nil { + return fmt.Errorf("failed to find TLS certificate file: %w", err) + } + if _, err := os.Stat(s.tlsKeyFile); err != nil { + return fmt.Errorf("failed to find TLS key file: %w", err) + } + return srv.ListenAndServeTLS(s.tlsCertFile, s.tlsKeyFile) + } + + return srv.ListenAndServe() +} + +// Shutdown gracefully stops the server, closing all active sessions +// and shutting down the HTTP server. +func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { + + // shutdown the server if needed (may use as a http.Handler) + s.mu.RLock() + srv := s.httpServer + s.mu.RUnlock() + if srv != nil { + return srv.Shutdown(ctx) + } + return nil +} + +// --- internal methods --- + +func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { + // post request carry request/notification message + + // Check content type + contentType := r.Header.Get("Content-Type") + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil || mediaType != "application/json" { + http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest) + return + } + + // Check the request body is valid json, meanwhile, get the request Method + rawData, err := io.ReadAll(r.Body) + if err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err)) + return + } + // First, try to parse as a response (sampling responses don't have a method field) + var jsonMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` + } + if err := json.Unmarshal(rawData, &jsonMessage); err != nil { + s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json") + return + } + + // Check if this is a sampling response (has result/error but no method) + isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + (jsonMessage.Result != nil || jsonMessage.Error != nil) + + isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize + + // Handle sampling responses separately + if isSamplingResponse { + if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil { + s.logger.Errorf("Failed to handle sampling response: %v", err) + http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError) + } + return + } + + // Prepare the session for the mcp server + // The session is ephemeral. Its life is the same as the request. It's only created + // for interaction with the mcp server. + var sessionID string + if isInitializeRequest { + // generate a new one for initialize request + sessionID = s.sessionIdManager.Generate() + } else { + // Get session ID from header. + // Stateful servers need the client to carry the session ID. + sessionID = r.Header.Get(HeaderKeySessionID) + isTerminated, err := s.sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return + } + } + + session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + + // Set the client context before handling the message + ctx := s.server.WithContext(r.Context(), session) + if s.contextFunc != nil { + ctx = s.contextFunc(ctx, r) + } + + // handle potential notifications + mu := sync.Mutex{} + upgradedHeader := false + done := make(chan struct{}) + + ctx = context.WithValue(ctx, requestHeader, r.Header) + go func() { + for { + select { + case nt := <-session.notificationChannel: + func() { + mu.Lock() + defer mu.Unlock() + // if the done chan is closed, as the request is terminated, just return + select { + case <-done: + return + default: + } + defer func() { + flusher, ok := w.(http.Flusher) + if ok { + flusher.Flush() + } + }() + + // if there's notifications, upgradedHeader to SSE response + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + upgradedHeader = true + } + err := writeSSEEvent(w, nt) + if err != nil { + s.logger.Errorf("Failed to write SSE event: %v", err) + return + } + }() + case <-done: + return + case <-ctx.Done(): + return + } + } + }() + + // Process message through MCPServer + response := s.server.HandleMessage(ctx, rawData) + if response == nil { + // For notifications, just send 202 Accepted with no body + w.WriteHeader(http.StatusAccepted) + return + } + + // Write response + mu.Lock() + defer mu.Unlock() + // close the done chan before unlock + defer close(done) + if ctx.Err() != nil { + return + } + // If client-server communication already upgraded to SSE stream + if session.upgradeToSSE.Load() { + if !upgradedHeader { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("Cache-Control", "no-cache") + w.WriteHeader(http.StatusOK) + upgradedHeader = true + } + if err := writeSSEEvent(w, response); err != nil { + s.logger.Errorf("Failed to write final SSE response event: %v", err) + } + } else { + w.Header().Set("Content-Type", "application/json") + if isInitializeRequest && sessionID != "" { + // send the session ID back to the client + w.Header().Set(HeaderKeySessionID, sessionID) + } + w.WriteHeader(http.StatusOK) + err := json.NewEncoder(w).Encode(response) + if err != nil { + s.logger.Errorf("Failed to write response: %v", err) + } + } +} + +func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) { + // get request is for listening to notifications + // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server + + sessionID := r.Header.Get(HeaderKeySessionID) + // the specification didn't say we should validate the session id + + if sessionID == "" { + // It's a stateless server, + // but the MCP server requires a unique ID for registering, so we use a random one + sessionID = uuid.New().String() + } + + session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) + if err := s.server.RegisterSession(r.Context(), session); err != nil { + http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) + return + } + defer s.server.UnregisterSession(r.Context(), sessionID) + + // Register session for sampling response delivery + s.activeSessions.Store(sessionID, session) + defer s.activeSessions.Delete(sessionID) + + // Set the client context before handling the message + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + flusher.Flush() + + // Start notification handler for this session + done := make(chan struct{}) + defer close(done) + writeChan := make(chan any, 16) + + go func() { + for { + select { + case nt := <-session.notificationChannel: + select { + case writeChan <- &nt: + case <-done: + return + } + case samplingReq := <-session.samplingRequestChan: + // Send sampling request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(samplingReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + Params: samplingReq.request.CreateMessageParams, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } + case <-done: + return + } + } + }() + + if s.listenHeartbeatInterval > 0 { + // heartbeat to keep the connection alive + go func() { + ticker := time.NewTicker(s.listenHeartbeatInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + message := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(s.nextRequestID(sessionID)), + Request: mcp.Request{ + Method: "ping", + }, + } + select { + case writeChan <- message: + case <-done: + return + } + case <-done: + return + } + } + }() + } + + // Keep the connection open until the client disconnects + // + // There's will a Available() check when handler ends, and it maybe race with Flush(), + // so we use a separate channel to send the data, inteading of flushing directly in other goroutine. + for { + select { + case data := <-writeChan: + if data == nil { + continue + } + if err := writeSSEEvent(w, data); err != nil { + s.logger.Errorf("Failed to write SSE event: %v", err) + return + } + flusher.Flush() + case <-r.Context().Done(): + return + } + } +} + +func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { + // delete request terminate the session + sessionID := r.Header.Get(HeaderKeySessionID) + notAllowed, err := s.sessionIdManager.Terminate(sessionID) + if err != nil { + http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError) + return + } + if notAllowed { + http.Error(w, "Session termination not allowed", http.StatusMethodNotAllowed) + return + } + + // remove the session relateddata from the sessionToolsStore + s.sessionTools.delete(sessionID) + s.sessionLogLevels.delete(sessionID) + // remove current session's requstID information + s.sessionRequestIDs.Delete(sessionID) + + w.WriteHeader(http.StatusOK) +} + +func writeSSEEvent(w io.Writer, data any) error { + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal data: %w", err) + } + _, err = fmt.Fprintf(w, "event: message\ndata: %s\n\n", jsonData) + if err != nil { + return fmt.Errorf("failed to write SSE event: %w", err) + } + return nil +} + +// handleSamplingResponse processes incoming sampling responses from clients +func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` +}) error { + // Get session ID from header + sessionID := r.Header.Get(HeaderKeySessionID) + if sessionID == "" { + http.Error(w, "Missing session ID for sampling response", http.StatusBadRequest) + return fmt.Errorf("missing session ID") + } + + // Validate session + isTerminated, err := s.sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return err + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return fmt.Errorf("session terminated") + } + + // Parse the request ID + var requestID int64 + if err := json.Unmarshal(responseMessage.ID, &requestID); err != nil { + http.Error(w, "Invalid request ID in sampling response", http.StatusBadRequest) + return err + } + + // Create the sampling response item + response := samplingResponseItem{ + requestID: requestID, + } + + // Parse result or error + if responseMessage.Error != nil { + // Parse error + var jsonrpcError struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(responseMessage.Error, &jsonrpcError); err != nil { + response.err = fmt.Errorf("failed to parse error: %v", err) + } else { + response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message) + } + } else if responseMessage.Result != nil { + // Parse result + var result mcp.CreateMessageResult + if err := json.Unmarshal(responseMessage.Result, &result); err != nil { + response.err = fmt.Errorf("failed to parse sampling result: %v", err) + } else { + response.result = &result + } + } else { + response.err = fmt.Errorf("sampling response has neither result nor error") + } + + // Find the corresponding session and deliver the response + // The response is delivered to the specific session identified by sessionID + if err := s.deliverSamplingResponse(sessionID, response); err != nil { + s.logger.Errorf("Failed to deliver sampling response: %v", err) + http.Error(w, "Failed to deliver response", http.StatusInternalServerError) + return err + } + + // Acknowledge receipt + w.WriteHeader(http.StatusOK) + return nil +} + +// deliverSamplingResponse delivers a sampling response to the appropriate session +func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error { + // Look up the active session + sessionInterface, ok := s.activeSessions.Load(sessionID) + if !ok { + return fmt.Errorf("no active session found for session %s", sessionID) + } + + session, ok := sessionInterface.(*streamableHttpSession) + if !ok { + return fmt.Errorf("invalid session type for session %s", sessionID) + } + + // Look up the dedicated response channel for this specific request + responseChannelInterface, exists := session.samplingRequests.Load(response.requestID) + if !exists { + return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID) + } + + responseChan, ok := responseChannelInterface.(chan samplingResponseItem) + if !ok { + return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID) + } + + // Attempt to deliver the response with timeout to prevent indefinite blocking + select { + case responseChan <- response: + s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) + return nil + default: + return fmt.Errorf("failed to deliver sampling response for session %s, request %d: channel full or blocked", sessionID, response.requestID) + } +} + +// writeJSONRPCError writes a JSON-RPC error response with the given error details. +func (s *StreamableHTTPServer) writeJSONRPCError( + w http.ResponseWriter, + id any, + code int, + message string, +) { + response := createErrorResponse(id, code, message) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + err := json.NewEncoder(w).Encode(response) + if err != nil { + s.logger.Errorf("Failed to write JSONRPCError: %v", err) + } +} + +// nextRequestID gets the next incrementing requestID for the current session +func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 { + actual, _ := s.sessionRequestIDs.LoadOrStore(sessionID, new(atomic.Int64)) + counter := actual.(*atomic.Int64) + return counter.Add(1) +} + +// --- session --- +type sessionLogLevelsStore struct { + mu sync.RWMutex + logs map[string]mcp.LoggingLevel +} + +func newSessionLogLevelsStore() *sessionLogLevelsStore { + return &sessionLogLevelsStore{ + logs: make(map[string]mcp.LoggingLevel), + } +} + +func (s *sessionLogLevelsStore) get(sessionID string) mcp.LoggingLevel { + s.mu.RLock() + defer s.mu.RUnlock() + val, ok := s.logs[sessionID] + if !ok { + return mcp.LoggingLevelError + } + return val +} + +func (s *sessionLogLevelsStore) set(sessionID string, level mcp.LoggingLevel) { + s.mu.Lock() + defer s.mu.Unlock() + s.logs[sessionID] = level +} + +func (s *sessionLogLevelsStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.logs, sessionID) +} + +type sessionToolsStore struct { + mu sync.RWMutex + tools map[string]map[string]ServerTool // sessionID -> toolName -> tool +} + +func newSessionToolsStore() *sessionToolsStore { + return &sessionToolsStore{ + tools: make(map[string]map[string]ServerTool), + } +} + +func (s *sessionToolsStore) get(sessionID string) map[string]ServerTool { + s.mu.RLock() + defer s.mu.RUnlock() + return s.tools[sessionID] +} + +func (s *sessionToolsStore) set(sessionID string, tools map[string]ServerTool) { + s.mu.Lock() + defer s.mu.Unlock() + s.tools[sessionID] = tools +} + +func (s *sessionToolsStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.tools, sessionID) +} + +// Sampling support types for HTTP transport +type samplingRequestItem struct { + requestID int64 + request mcp.CreateMessageRequest + response chan samplingResponseItem +} + +type samplingResponseItem struct { + requestID int64 + result *mcp.CreateMessageResult + err error +} + +// streamableHttpSession is a session for streamable-http transport +// When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. +// When in GET handlers(listening), it's a real session, and will be registered in the MCP server. +type streamableHttpSession struct { + sessionID string + notificationChannel chan mcp.JSONRPCNotification // server -> client notifications + tools *sessionToolsStore + upgradeToSSE atomic.Bool + logLevels *sessionLogLevelsStore + + // Sampling support for bidirectional communication + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + samplingRequests sync.Map // requestID -> pending sampling request context + requestIDCounter atomic.Int64 // for generating unique request IDs +} + +func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { + s := &streamableHttpSession{ + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), + } + return s +} + +func (s *streamableHttpSession) SessionID() string { + return s.sessionID +} + +func (s *streamableHttpSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notificationChannel +} + +func (s *streamableHttpSession) Initialize() { + // do nothing + // the session is ephemeral, no real initialized action needed +} + +func (s *streamableHttpSession) Initialized() bool { + // the session is ephemeral, no real initialized action needed + return true +} + +func (s *streamableHttpSession) SetLogLevel(level mcp.LoggingLevel) { + s.logLevels.set(s.sessionID, level) +} + +func (s *streamableHttpSession) GetLogLevel() mcp.LoggingLevel { + return s.logLevels.get(s.sessionID) +} + +var _ ClientSession = (*streamableHttpSession)(nil) + +func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool { + return s.tools.get(s.sessionID) +} + +func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { + s.tools.set(s.sessionID, tools) +} + +var ( + _ SessionWithTools = (*streamableHttpSession)(nil) + _ SessionWithLogging = (*streamableHttpSession)(nil) +) + +func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { + s.upgradeToSSE.Store(true) +} + +var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) + +// RequestSampling implements SessionWithSampling interface for HTTP transport +func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the sampling request item + samplingRequest := samplingRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the sampling request via the channel (non-blocking) + select { + case s.samplingRequestChan <- samplingRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("sampling request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +var _ SessionWithSampling = (*streamableHttpSession)(nil) + +// --- session id manager --- + +type SessionIdManager interface { + Generate() string + // Validate checks if a session ID is valid and not terminated. + // Returns isTerminated=true if the ID is valid but belongs to a terminated session. + // Returns err!=nil if the ID format is invalid or lookup failed. + Validate(sessionID string) (isTerminated bool, err error) + // Terminate marks a session ID as terminated. + // Returns isNotAllowed=true if the server policy prevents client termination. + // Returns err!=nil if the ID is invalid or termination failed. + Terminate(sessionID string) (isNotAllowed bool, err error) +} + +// StatelessSessionIdManager does nothing, which means it has no session management, which is stateless. +type StatelessSessionIdManager struct{} + +func (s *StatelessSessionIdManager) Generate() string { + return "" +} + +func (s *StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { + // In stateless mode, ignore session IDs completely - don't validate or reject them + return false, nil +} + +func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + return false, nil +} + +// InsecureStatefulSessionIdManager generate id with uuid +// It won't validate the id indeed, so it could be fake. +// For more secure session id, use a more complex generator, like a JWT. +type InsecureStatefulSessionIdManager struct{} + +const idPrefix = "mcp-session-" + +func (s *InsecureStatefulSessionIdManager) Generate() string { + return idPrefix + uuid.New().String() +} + +func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { + // validate the session id is a valid uuid + if !strings.HasPrefix(sessionID, idPrefix) { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + if _, err := uuid.Parse(sessionID[len(idPrefix):]); err != nil { + return false, fmt.Errorf("invalid session id: %s", sessionID) + } + return false, nil +} + +func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { + return false, nil +} + +// NewTestStreamableHTTPServer creates a test server for testing purposes +func NewTestStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *httptest.Server { + sseServer := NewStreamableHTTPServer(server, opts...) + testServer := httptest.NewServer(sseServer) + return testServer +} diff --git a/vendor/github.com/mark3labs/mcp-go/util/logger.go b/vendor/github.com/mark3labs/mcp-go/util/logger.go new file mode 100644 index 0000000000..8d7555ce35 --- /dev/null +++ b/vendor/github.com/mark3labs/mcp-go/util/logger.go @@ -0,0 +1,33 @@ +package util + +import ( + "log" +) + +// Logger defines a minimal logging interface +type Logger interface { + Infof(format string, v ...any) + Errorf(format string, v ...any) +} + +// --- Standard Library Logger Wrapper --- + +// DefaultStdLogger implements Logger using the standard library's log.Logger. +func DefaultLogger() Logger { + return &stdLogger{ + logger: log.Default(), + } +} + +// stdLogger wraps the standard library's log.Logger. +type stdLogger struct { + logger *log.Logger +} + +func (l *stdLogger) Infof(format string, v ...any) { + l.logger.Printf("INFO: "+format, v...) +} + +func (l *stdLogger) Errorf(format string, v ...any) { + l.logger.Printf("ERROR: "+format, v...) +} diff --git a/vendor/github.com/spf13/cast/.gitignore b/vendor/github.com/spf13/cast/.gitignore new file mode 100644 index 0000000000..53053a8ac5 --- /dev/null +++ b/vendor/github.com/spf13/cast/.gitignore @@ -0,0 +1,25 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test + +*.bench diff --git a/vendor/github.com/spf13/cast/LICENSE b/vendor/github.com/spf13/cast/LICENSE new file mode 100644 index 0000000000..4527efb9c0 --- /dev/null +++ b/vendor/github.com/spf13/cast/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014 Steve Francia + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/vendor/github.com/spf13/cast/Makefile b/vendor/github.com/spf13/cast/Makefile new file mode 100644 index 0000000000..f01a5dbb6e --- /dev/null +++ b/vendor/github.com/spf13/cast/Makefile @@ -0,0 +1,40 @@ +GOVERSION := $(shell go version | cut -d ' ' -f 3 | cut -d '.' -f 2) + +.PHONY: check fmt lint test test-race vet test-cover-html help +.DEFAULT_GOAL := help + +check: test-race fmt vet lint ## Run tests and linters + +test: ## Run tests + go test ./... + +test-race: ## Run tests with race detector + go test -race ./... + +fmt: ## Run gofmt linter +ifeq "$(GOVERSION)" "12" + @for d in `go list` ; do \ + if [ "`gofmt -l -s $$GOPATH/src/$$d | tee /dev/stderr`" ]; then \ + echo "^ improperly formatted go files" && echo && exit 1; \ + fi \ + done +endif + +lint: ## Run golint linter + @for d in `go list` ; do \ + if [ "`golint $$d | tee /dev/stderr`" ]; then \ + echo "^ golint errors!" && echo && exit 1; \ + fi \ + done + +vet: ## Run go vet linter + @if [ "`go vet | tee /dev/stderr`" ]; then \ + echo "^ go vet errors!" && echo && exit 1; \ + fi + +test-cover-html: ## Generate test coverage report + go test -coverprofile=coverage.out -covermode=count + go tool cover -func=coverage.out + +help: + @grep -E '^[a-zA-Z0-9_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' diff --git a/vendor/github.com/spf13/cast/README.md b/vendor/github.com/spf13/cast/README.md new file mode 100644 index 0000000000..1be666a456 --- /dev/null +++ b/vendor/github.com/spf13/cast/README.md @@ -0,0 +1,75 @@ +# cast + +[![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/spf13/cast/test.yaml?branch=master&style=flat-square)](https://github.com/spf13/cast/actions/workflows/test.yaml) +[![PkgGoDev](https://pkg.go.dev/badge/mod/github.com/spf13/cast)](https://pkg.go.dev/mod/github.com/spf13/cast) +![Go Version](https://img.shields.io/badge/go%20version-%3E=1.16-61CFDD.svg?style=flat-square) +[![Go Report Card](https://goreportcard.com/badge/github.com/spf13/cast?style=flat-square)](https://goreportcard.com/report/github.com/spf13/cast) + +Easy and safe casting from one type to another in Go + +Don’t Panic! ... Cast + +## What is Cast? + +Cast is a library to convert between different go types in a consistent and easy way. + +Cast provides simple functions to easily convert a number to a string, an +interface into a bool, etc. Cast does this intelligently when an obvious +conversion is possible. It doesn’t make any attempts to guess what you meant, +for example you can only convert a string to an int when it is a string +representation of an int such as “8”. Cast was developed for use in +[Hugo](https://gohugo.io), a website engine which uses YAML, TOML or JSON +for meta data. + +## Why use Cast? + +When working with dynamic data in Go you often need to cast or convert the data +from one type into another. Cast goes beyond just using type assertion (though +it uses that when possible) to provide a very straightforward and convenient +library. + +If you are working with interfaces to handle things like dynamic content +you’ll need an easy way to convert an interface into a given type. This +is the library for you. + +If you are taking in data from YAML, TOML or JSON or other formats which lack +full types, then Cast is the library for you. + +## Usage + +Cast provides a handful of To_____ methods. These methods will always return +the desired type. **If input is provided that will not convert to that type, the +0 or nil value for that type will be returned**. + +Cast also provides identical methods To_____E. These return the same result as +the To_____ methods, plus an additional error which tells you if it successfully +converted. Using these methods you can tell the difference between when the +input matched the zero value or when the conversion failed and the zero value +was returned. + +The following examples are merely a sample of what is available. Please review +the code for a complete set. + +### Example ‘ToString’: + + cast.ToString("mayonegg") // "mayonegg" + cast.ToString(8) // "8" + cast.ToString(8.31) // "8.31" + cast.ToString([]byte("one time")) // "one time" + cast.ToString(nil) // "" + + var foo interface{} = "one more time" + cast.ToString(foo) // "one more time" + + +### Example ‘ToInt’: + + cast.ToInt(8) // 8 + cast.ToInt(8.31) // 8 + cast.ToInt("8") // 8 + cast.ToInt(true) // 1 + cast.ToInt(false) // 0 + + var eight interface{} = 8 + cast.ToInt(eight) // 8 + cast.ToInt(nil) // 0 diff --git a/vendor/github.com/spf13/cast/cast.go b/vendor/github.com/spf13/cast/cast.go new file mode 100644 index 0000000000..0cfe9418de --- /dev/null +++ b/vendor/github.com/spf13/cast/cast.go @@ -0,0 +1,176 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package cast provides easy and safe casting in Go. +package cast + +import "time" + +// ToBool casts an interface to a bool type. +func ToBool(i interface{}) bool { + v, _ := ToBoolE(i) + return v +} + +// ToTime casts an interface to a time.Time type. +func ToTime(i interface{}) time.Time { + v, _ := ToTimeE(i) + return v +} + +func ToTimeInDefaultLocation(i interface{}, location *time.Location) time.Time { + v, _ := ToTimeInDefaultLocationE(i, location) + return v +} + +// ToDuration casts an interface to a time.Duration type. +func ToDuration(i interface{}) time.Duration { + v, _ := ToDurationE(i) + return v +} + +// ToFloat64 casts an interface to a float64 type. +func ToFloat64(i interface{}) float64 { + v, _ := ToFloat64E(i) + return v +} + +// ToFloat32 casts an interface to a float32 type. +func ToFloat32(i interface{}) float32 { + v, _ := ToFloat32E(i) + return v +} + +// ToInt64 casts an interface to an int64 type. +func ToInt64(i interface{}) int64 { + v, _ := ToInt64E(i) + return v +} + +// ToInt32 casts an interface to an int32 type. +func ToInt32(i interface{}) int32 { + v, _ := ToInt32E(i) + return v +} + +// ToInt16 casts an interface to an int16 type. +func ToInt16(i interface{}) int16 { + v, _ := ToInt16E(i) + return v +} + +// ToInt8 casts an interface to an int8 type. +func ToInt8(i interface{}) int8 { + v, _ := ToInt8E(i) + return v +} + +// ToInt casts an interface to an int type. +func ToInt(i interface{}) int { + v, _ := ToIntE(i) + return v +} + +// ToUint casts an interface to a uint type. +func ToUint(i interface{}) uint { + v, _ := ToUintE(i) + return v +} + +// ToUint64 casts an interface to a uint64 type. +func ToUint64(i interface{}) uint64 { + v, _ := ToUint64E(i) + return v +} + +// ToUint32 casts an interface to a uint32 type. +func ToUint32(i interface{}) uint32 { + v, _ := ToUint32E(i) + return v +} + +// ToUint16 casts an interface to a uint16 type. +func ToUint16(i interface{}) uint16 { + v, _ := ToUint16E(i) + return v +} + +// ToUint8 casts an interface to a uint8 type. +func ToUint8(i interface{}) uint8 { + v, _ := ToUint8E(i) + return v +} + +// ToString casts an interface to a string type. +func ToString(i interface{}) string { + v, _ := ToStringE(i) + return v +} + +// ToStringMapString casts an interface to a map[string]string type. +func ToStringMapString(i interface{}) map[string]string { + v, _ := ToStringMapStringE(i) + return v +} + +// ToStringMapStringSlice casts an interface to a map[string][]string type. +func ToStringMapStringSlice(i interface{}) map[string][]string { + v, _ := ToStringMapStringSliceE(i) + return v +} + +// ToStringMapBool casts an interface to a map[string]bool type. +func ToStringMapBool(i interface{}) map[string]bool { + v, _ := ToStringMapBoolE(i) + return v +} + +// ToStringMapInt casts an interface to a map[string]int type. +func ToStringMapInt(i interface{}) map[string]int { + v, _ := ToStringMapIntE(i) + return v +} + +// ToStringMapInt64 casts an interface to a map[string]int64 type. +func ToStringMapInt64(i interface{}) map[string]int64 { + v, _ := ToStringMapInt64E(i) + return v +} + +// ToStringMap casts an interface to a map[string]interface{} type. +func ToStringMap(i interface{}) map[string]interface{} { + v, _ := ToStringMapE(i) + return v +} + +// ToSlice casts an interface to a []interface{} type. +func ToSlice(i interface{}) []interface{} { + v, _ := ToSliceE(i) + return v +} + +// ToBoolSlice casts an interface to a []bool type. +func ToBoolSlice(i interface{}) []bool { + v, _ := ToBoolSliceE(i) + return v +} + +// ToStringSlice casts an interface to a []string type. +func ToStringSlice(i interface{}) []string { + v, _ := ToStringSliceE(i) + return v +} + +// ToIntSlice casts an interface to a []int type. +func ToIntSlice(i interface{}) []int { + v, _ := ToIntSliceE(i) + return v +} + +// ToDurationSlice casts an interface to a []time.Duration type. +func ToDurationSlice(i interface{}) []time.Duration { + v, _ := ToDurationSliceE(i) + return v +} diff --git a/vendor/github.com/spf13/cast/caste.go b/vendor/github.com/spf13/cast/caste.go new file mode 100644 index 0000000000..4181a2e758 --- /dev/null +++ b/vendor/github.com/spf13/cast/caste.go @@ -0,0 +1,1510 @@ +// Copyright © 2014 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package cast + +import ( + "encoding/json" + "errors" + "fmt" + "html/template" + "reflect" + "strconv" + "strings" + "time" +) + +var errNegativeNotAllowed = errors.New("unable to cast negative value") + +type float64EProvider interface { + Float64() (float64, error) +} + +type float64Provider interface { + Float64() float64 +} + +// ToTimeE casts an interface to a time.Time type. +func ToTimeE(i interface{}) (tim time.Time, err error) { + return ToTimeInDefaultLocationE(i, time.UTC) +} + +// ToTimeInDefaultLocationE casts an empty interface to time.Time, +// interpreting inputs without a timezone to be in the given location, +// or the local timezone if nil. +func ToTimeInDefaultLocationE(i interface{}, location *time.Location) (tim time.Time, err error) { + i = indirect(i) + + switch v := i.(type) { + case time.Time: + return v, nil + case string: + return StringToDateInDefaultLocation(v, location) + case json.Number: + s, err1 := ToInt64E(v) + if err1 != nil { + return time.Time{}, fmt.Errorf("unable to cast %#v of type %T to Time", i, i) + } + return time.Unix(s, 0), nil + case int: + return time.Unix(int64(v), 0), nil + case int64: + return time.Unix(v, 0), nil + case int32: + return time.Unix(int64(v), 0), nil + case uint: + return time.Unix(int64(v), 0), nil + case uint64: + return time.Unix(int64(v), 0), nil + case uint32: + return time.Unix(int64(v), 0), nil + default: + return time.Time{}, fmt.Errorf("unable to cast %#v of type %T to Time", i, i) + } +} + +// ToDurationE casts an interface to a time.Duration type. +func ToDurationE(i interface{}) (d time.Duration, err error) { + i = indirect(i) + + switch s := i.(type) { + case time.Duration: + return s, nil + case int, int64, int32, int16, int8, uint, uint64, uint32, uint16, uint8: + d = time.Duration(ToInt64(s)) + return + case float32, float64: + d = time.Duration(ToFloat64(s)) + return + case string: + if strings.ContainsAny(s, "nsuµmh") { + d, err = time.ParseDuration(s) + } else { + d, err = time.ParseDuration(s + "ns") + } + return + case float64EProvider: + var v float64 + v, err = s.Float64() + d = time.Duration(v) + return + case float64Provider: + d = time.Duration(s.Float64()) + return + default: + err = fmt.Errorf("unable to cast %#v of type %T to Duration", i, i) + return + } +} + +// ToBoolE casts an interface to a bool type. +func ToBoolE(i interface{}) (bool, error) { + i = indirect(i) + + switch b := i.(type) { + case bool: + return b, nil + case nil: + return false, nil + case int: + return b != 0, nil + case int64: + return b != 0, nil + case int32: + return b != 0, nil + case int16: + return b != 0, nil + case int8: + return b != 0, nil + case uint: + return b != 0, nil + case uint64: + return b != 0, nil + case uint32: + return b != 0, nil + case uint16: + return b != 0, nil + case uint8: + return b != 0, nil + case float64: + return b != 0, nil + case float32: + return b != 0, nil + case time.Duration: + return b != 0, nil + case string: + return strconv.ParseBool(i.(string)) + case json.Number: + v, err := ToInt64E(b) + if err == nil { + return v != 0, nil + } + return false, fmt.Errorf("unable to cast %#v of type %T to bool", i, i) + default: + return false, fmt.Errorf("unable to cast %#v of type %T to bool", i, i) + } +} + +// ToFloat64E casts an interface to a float64 type. +func ToFloat64E(i interface{}) (float64, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + return float64(intv), nil + } + + switch s := i.(type) { + case float64: + return s, nil + case float32: + return float64(s), nil + case int64: + return float64(s), nil + case int32: + return float64(s), nil + case int16: + return float64(s), nil + case int8: + return float64(s), nil + case uint: + return float64(s), nil + case uint64: + return float64(s), nil + case uint32: + return float64(s), nil + case uint16: + return float64(s), nil + case uint8: + return float64(s), nil + case string: + v, err := strconv.ParseFloat(s, 64) + if err == nil { + return v, nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to float64", i, i) + case float64EProvider: + v, err := s.Float64() + if err == nil { + return v, nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to float64", i, i) + case float64Provider: + return s.Float64(), nil + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to float64", i, i) + } +} + +// ToFloat32E casts an interface to a float32 type. +func ToFloat32E(i interface{}) (float32, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + return float32(intv), nil + } + + switch s := i.(type) { + case float64: + return float32(s), nil + case float32: + return s, nil + case int64: + return float32(s), nil + case int32: + return float32(s), nil + case int16: + return float32(s), nil + case int8: + return float32(s), nil + case uint: + return float32(s), nil + case uint64: + return float32(s), nil + case uint32: + return float32(s), nil + case uint16: + return float32(s), nil + case uint8: + return float32(s), nil + case string: + v, err := strconv.ParseFloat(s, 32) + if err == nil { + return float32(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to float32", i, i) + case float64EProvider: + v, err := s.Float64() + if err == nil { + return float32(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to float32", i, i) + case float64Provider: + return float32(s.Float64()), nil + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to float32", i, i) + } +} + +// ToInt64E casts an interface to an int64 type. +func ToInt64E(i interface{}) (int64, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + return int64(intv), nil + } + + switch s := i.(type) { + case int64: + return s, nil + case int32: + return int64(s), nil + case int16: + return int64(s), nil + case int8: + return int64(s), nil + case uint: + return int64(s), nil + case uint64: + return int64(s), nil + case uint32: + return int64(s), nil + case uint16: + return int64(s), nil + case uint8: + return int64(s), nil + case float64: + return int64(s), nil + case float32: + return int64(s), nil + case string: + v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0) + if err == nil { + return v, nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to int64", i, i) + case json.Number: + return ToInt64E(string(s)) + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to int64", i, i) + } +} + +// ToInt32E casts an interface to an int32 type. +func ToInt32E(i interface{}) (int32, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + return int32(intv), nil + } + + switch s := i.(type) { + case int64: + return int32(s), nil + case int32: + return s, nil + case int16: + return int32(s), nil + case int8: + return int32(s), nil + case uint: + return int32(s), nil + case uint64: + return int32(s), nil + case uint32: + return int32(s), nil + case uint16: + return int32(s), nil + case uint8: + return int32(s), nil + case float64: + return int32(s), nil + case float32: + return int32(s), nil + case string: + v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0) + if err == nil { + return int32(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to int32", i, i) + case json.Number: + return ToInt32E(string(s)) + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to int32", i, i) + } +} + +// ToInt16E casts an interface to an int16 type. +func ToInt16E(i interface{}) (int16, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + return int16(intv), nil + } + + switch s := i.(type) { + case int64: + return int16(s), nil + case int32: + return int16(s), nil + case int16: + return s, nil + case int8: + return int16(s), nil + case uint: + return int16(s), nil + case uint64: + return int16(s), nil + case uint32: + return int16(s), nil + case uint16: + return int16(s), nil + case uint8: + return int16(s), nil + case float64: + return int16(s), nil + case float32: + return int16(s), nil + case string: + v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0) + if err == nil { + return int16(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to int16", i, i) + case json.Number: + return ToInt16E(string(s)) + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to int16", i, i) + } +} + +// ToInt8E casts an interface to an int8 type. +func ToInt8E(i interface{}) (int8, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + return int8(intv), nil + } + + switch s := i.(type) { + case int64: + return int8(s), nil + case int32: + return int8(s), nil + case int16: + return int8(s), nil + case int8: + return s, nil + case uint: + return int8(s), nil + case uint64: + return int8(s), nil + case uint32: + return int8(s), nil + case uint16: + return int8(s), nil + case uint8: + return int8(s), nil + case float64: + return int8(s), nil + case float32: + return int8(s), nil + case string: + v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0) + if err == nil { + return int8(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to int8", i, i) + case json.Number: + return ToInt8E(string(s)) + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to int8", i, i) + } +} + +// ToIntE casts an interface to an int type. +func ToIntE(i interface{}) (int, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + return intv, nil + } + + switch s := i.(type) { + case int64: + return int(s), nil + case int32: + return int(s), nil + case int16: + return int(s), nil + case int8: + return int(s), nil + case uint: + return int(s), nil + case uint64: + return int(s), nil + case uint32: + return int(s), nil + case uint16: + return int(s), nil + case uint8: + return int(s), nil + case float64: + return int(s), nil + case float32: + return int(s), nil + case string: + v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0) + if err == nil { + return int(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to int64", i, i) + case json.Number: + return ToIntE(string(s)) + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to int", i, i) + } +} + +// ToUintE casts an interface to a uint type. +func ToUintE(i interface{}) (uint, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + if intv < 0 { + return 0, errNegativeNotAllowed + } + return uint(intv), nil + } + + switch s := i.(type) { + case string: + v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0) + if err == nil { + if v < 0 { + return 0, errNegativeNotAllowed + } + return uint(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to uint", i, i) + case json.Number: + return ToUintE(string(s)) + case int64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint(s), nil + case int32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint(s), nil + case int16: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint(s), nil + case int8: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint(s), nil + case uint: + return s, nil + case uint64: + return uint(s), nil + case uint32: + return uint(s), nil + case uint16: + return uint(s), nil + case uint8: + return uint(s), nil + case float64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint(s), nil + case float32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint(s), nil + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to uint", i, i) + } +} + +// ToUint64E casts an interface to a uint64 type. +func ToUint64E(i interface{}) (uint64, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + if intv < 0 { + return 0, errNegativeNotAllowed + } + return uint64(intv), nil + } + + switch s := i.(type) { + case string: + v, err := strconv.ParseUint(trimZeroDecimal(s), 0, 0) + if err == nil { + if v < 0 { + return 0, errNegativeNotAllowed + } + return v, nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to uint64", i, i) + case json.Number: + return ToUint64E(string(s)) + case int64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint64(s), nil + case int32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint64(s), nil + case int16: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint64(s), nil + case int8: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint64(s), nil + case uint: + return uint64(s), nil + case uint64: + return s, nil + case uint32: + return uint64(s), nil + case uint16: + return uint64(s), nil + case uint8: + return uint64(s), nil + case float32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint64(s), nil + case float64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint64(s), nil + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to uint64", i, i) + } +} + +// ToUint32E casts an interface to a uint32 type. +func ToUint32E(i interface{}) (uint32, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + if intv < 0 { + return 0, errNegativeNotAllowed + } + return uint32(intv), nil + } + + switch s := i.(type) { + case string: + v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0) + if err == nil { + if v < 0 { + return 0, errNegativeNotAllowed + } + return uint32(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to uint32", i, i) + case json.Number: + return ToUint32E(string(s)) + case int64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint32(s), nil + case int32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint32(s), nil + case int16: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint32(s), nil + case int8: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint32(s), nil + case uint: + return uint32(s), nil + case uint64: + return uint32(s), nil + case uint32: + return s, nil + case uint16: + return uint32(s), nil + case uint8: + return uint32(s), nil + case float64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint32(s), nil + case float32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint32(s), nil + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to uint32", i, i) + } +} + +// ToUint16E casts an interface to a uint16 type. +func ToUint16E(i interface{}) (uint16, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + if intv < 0 { + return 0, errNegativeNotAllowed + } + return uint16(intv), nil + } + + switch s := i.(type) { + case string: + v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0) + if err == nil { + if v < 0 { + return 0, errNegativeNotAllowed + } + return uint16(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to uint16", i, i) + case json.Number: + return ToUint16E(string(s)) + case int64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint16(s), nil + case int32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint16(s), nil + case int16: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint16(s), nil + case int8: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint16(s), nil + case uint: + return uint16(s), nil + case uint64: + return uint16(s), nil + case uint32: + return uint16(s), nil + case uint16: + return s, nil + case uint8: + return uint16(s), nil + case float64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint16(s), nil + case float32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint16(s), nil + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to uint16", i, i) + } +} + +// ToUint8E casts an interface to a uint type. +func ToUint8E(i interface{}) (uint8, error) { + i = indirect(i) + + intv, ok := toInt(i) + if ok { + if intv < 0 { + return 0, errNegativeNotAllowed + } + return uint8(intv), nil + } + + switch s := i.(type) { + case string: + v, err := strconv.ParseInt(trimZeroDecimal(s), 0, 0) + if err == nil { + if v < 0 { + return 0, errNegativeNotAllowed + } + return uint8(v), nil + } + return 0, fmt.Errorf("unable to cast %#v of type %T to uint8", i, i) + case json.Number: + return ToUint8E(string(s)) + case int64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint8(s), nil + case int32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint8(s), nil + case int16: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint8(s), nil + case int8: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint8(s), nil + case uint: + return uint8(s), nil + case uint64: + return uint8(s), nil + case uint32: + return uint8(s), nil + case uint16: + return uint8(s), nil + case uint8: + return s, nil + case float64: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint8(s), nil + case float32: + if s < 0 { + return 0, errNegativeNotAllowed + } + return uint8(s), nil + case bool: + if s { + return 1, nil + } + return 0, nil + case nil: + return 0, nil + default: + return 0, fmt.Errorf("unable to cast %#v of type %T to uint8", i, i) + } +} + +// From html/template/content.go +// Copyright 2011 The Go Authors. All rights reserved. +// indirect returns the value, after dereferencing as many times +// as necessary to reach the base type (or nil). +func indirect(a interface{}) interface{} { + if a == nil { + return nil + } + if t := reflect.TypeOf(a); t.Kind() != reflect.Ptr { + // Avoid creating a reflect.Value if it's not a pointer. + return a + } + v := reflect.ValueOf(a) + for v.Kind() == reflect.Ptr && !v.IsNil() { + v = v.Elem() + } + return v.Interface() +} + +// From html/template/content.go +// Copyright 2011 The Go Authors. All rights reserved. +// indirectToStringerOrError returns the value, after dereferencing as many times +// as necessary to reach the base type (or nil) or an implementation of fmt.Stringer +// or error, +func indirectToStringerOrError(a interface{}) interface{} { + if a == nil { + return nil + } + + errorType := reflect.TypeOf((*error)(nil)).Elem() + fmtStringerType := reflect.TypeOf((*fmt.Stringer)(nil)).Elem() + + v := reflect.ValueOf(a) + for !v.Type().Implements(fmtStringerType) && !v.Type().Implements(errorType) && v.Kind() == reflect.Ptr && !v.IsNil() { + v = v.Elem() + } + return v.Interface() +} + +// ToStringE casts an interface to a string type. +func ToStringE(i interface{}) (string, error) { + i = indirectToStringerOrError(i) + + switch s := i.(type) { + case string: + return s, nil + case bool: + return strconv.FormatBool(s), nil + case float64: + return strconv.FormatFloat(s, 'f', -1, 64), nil + case float32: + return strconv.FormatFloat(float64(s), 'f', -1, 32), nil + case int: + return strconv.Itoa(s), nil + case int64: + return strconv.FormatInt(s, 10), nil + case int32: + return strconv.Itoa(int(s)), nil + case int16: + return strconv.FormatInt(int64(s), 10), nil + case int8: + return strconv.FormatInt(int64(s), 10), nil + case uint: + return strconv.FormatUint(uint64(s), 10), nil + case uint64: + return strconv.FormatUint(uint64(s), 10), nil + case uint32: + return strconv.FormatUint(uint64(s), 10), nil + case uint16: + return strconv.FormatUint(uint64(s), 10), nil + case uint8: + return strconv.FormatUint(uint64(s), 10), nil + case json.Number: + return s.String(), nil + case []byte: + return string(s), nil + case template.HTML: + return string(s), nil + case template.URL: + return string(s), nil + case template.JS: + return string(s), nil + case template.CSS: + return string(s), nil + case template.HTMLAttr: + return string(s), nil + case nil: + return "", nil + case fmt.Stringer: + return s.String(), nil + case error: + return s.Error(), nil + default: + return "", fmt.Errorf("unable to cast %#v of type %T to string", i, i) + } +} + +// ToStringMapStringE casts an interface to a map[string]string type. +func ToStringMapStringE(i interface{}) (map[string]string, error) { + m := map[string]string{} + + switch v := i.(type) { + case map[string]string: + return v, nil + case map[string]interface{}: + for k, val := range v { + m[ToString(k)] = ToString(val) + } + return m, nil + case map[interface{}]string: + for k, val := range v { + m[ToString(k)] = ToString(val) + } + return m, nil + case map[interface{}]interface{}: + for k, val := range v { + m[ToString(k)] = ToString(val) + } + return m, nil + case string: + err := jsonStringToObject(v, &m) + return m, err + default: + return m, fmt.Errorf("unable to cast %#v of type %T to map[string]string", i, i) + } +} + +// ToStringMapStringSliceE casts an interface to a map[string][]string type. +func ToStringMapStringSliceE(i interface{}) (map[string][]string, error) { + m := map[string][]string{} + + switch v := i.(type) { + case map[string][]string: + return v, nil + case map[string][]interface{}: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[string]string: + for k, val := range v { + m[ToString(k)] = []string{val} + } + case map[string]interface{}: + for k, val := range v { + switch vt := val.(type) { + case []interface{}: + m[ToString(k)] = ToStringSlice(vt) + case []string: + m[ToString(k)] = vt + default: + m[ToString(k)] = []string{ToString(val)} + } + } + return m, nil + case map[interface{}][]string: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[interface{}]string: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[interface{}][]interface{}: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[interface{}]interface{}: + for k, val := range v { + key, err := ToStringE(k) + if err != nil { + return m, fmt.Errorf("unable to cast %#v of type %T to map[string][]string", i, i) + } + value, err := ToStringSliceE(val) + if err != nil { + return m, fmt.Errorf("unable to cast %#v of type %T to map[string][]string", i, i) + } + m[key] = value + } + case string: + err := jsonStringToObject(v, &m) + return m, err + default: + return m, fmt.Errorf("unable to cast %#v of type %T to map[string][]string", i, i) + } + return m, nil +} + +// ToStringMapBoolE casts an interface to a map[string]bool type. +func ToStringMapBoolE(i interface{}) (map[string]bool, error) { + m := map[string]bool{} + + switch v := i.(type) { + case map[interface{}]interface{}: + for k, val := range v { + m[ToString(k)] = ToBool(val) + } + return m, nil + case map[string]interface{}: + for k, val := range v { + m[ToString(k)] = ToBool(val) + } + return m, nil + case map[string]bool: + return v, nil + case string: + err := jsonStringToObject(v, &m) + return m, err + default: + return m, fmt.Errorf("unable to cast %#v of type %T to map[string]bool", i, i) + } +} + +// ToStringMapE casts an interface to a map[string]interface{} type. +func ToStringMapE(i interface{}) (map[string]interface{}, error) { + m := map[string]interface{}{} + + switch v := i.(type) { + case map[interface{}]interface{}: + for k, val := range v { + m[ToString(k)] = val + } + return m, nil + case map[string]interface{}: + return v, nil + case string: + err := jsonStringToObject(v, &m) + return m, err + default: + return m, fmt.Errorf("unable to cast %#v of type %T to map[string]interface{}", i, i) + } +} + +// ToStringMapIntE casts an interface to a map[string]int{} type. +func ToStringMapIntE(i interface{}) (map[string]int, error) { + m := map[string]int{} + if i == nil { + return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int", i, i) + } + + switch v := i.(type) { + case map[interface{}]interface{}: + for k, val := range v { + m[ToString(k)] = ToInt(val) + } + return m, nil + case map[string]interface{}: + for k, val := range v { + m[k] = ToInt(val) + } + return m, nil + case map[string]int: + return v, nil + case string: + err := jsonStringToObject(v, &m) + return m, err + } + + if reflect.TypeOf(i).Kind() != reflect.Map { + return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int", i, i) + } + + mVal := reflect.ValueOf(m) + v := reflect.ValueOf(i) + for _, keyVal := range v.MapKeys() { + val, err := ToIntE(v.MapIndex(keyVal).Interface()) + if err != nil { + return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int", i, i) + } + mVal.SetMapIndex(keyVal, reflect.ValueOf(val)) + } + return m, nil +} + +// ToStringMapInt64E casts an interface to a map[string]int64{} type. +func ToStringMapInt64E(i interface{}) (map[string]int64, error) { + m := map[string]int64{} + if i == nil { + return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int64", i, i) + } + + switch v := i.(type) { + case map[interface{}]interface{}: + for k, val := range v { + m[ToString(k)] = ToInt64(val) + } + return m, nil + case map[string]interface{}: + for k, val := range v { + m[k] = ToInt64(val) + } + return m, nil + case map[string]int64: + return v, nil + case string: + err := jsonStringToObject(v, &m) + return m, err + } + + if reflect.TypeOf(i).Kind() != reflect.Map { + return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int64", i, i) + } + mVal := reflect.ValueOf(m) + v := reflect.ValueOf(i) + for _, keyVal := range v.MapKeys() { + val, err := ToInt64E(v.MapIndex(keyVal).Interface()) + if err != nil { + return m, fmt.Errorf("unable to cast %#v of type %T to map[string]int64", i, i) + } + mVal.SetMapIndex(keyVal, reflect.ValueOf(val)) + } + return m, nil +} + +// ToSliceE casts an interface to a []interface{} type. +func ToSliceE(i interface{}) ([]interface{}, error) { + var s []interface{} + + switch v := i.(type) { + case []interface{}: + return append(s, v...), nil + case []map[string]interface{}: + for _, u := range v { + s = append(s, u) + } + return s, nil + default: + return s, fmt.Errorf("unable to cast %#v of type %T to []interface{}", i, i) + } +} + +// ToBoolSliceE casts an interface to a []bool type. +func ToBoolSliceE(i interface{}) ([]bool, error) { + if i == nil { + return []bool{}, fmt.Errorf("unable to cast %#v of type %T to []bool", i, i) + } + + switch v := i.(type) { + case []bool: + return v, nil + } + + kind := reflect.TypeOf(i).Kind() + switch kind { + case reflect.Slice, reflect.Array: + s := reflect.ValueOf(i) + a := make([]bool, s.Len()) + for j := 0; j < s.Len(); j++ { + val, err := ToBoolE(s.Index(j).Interface()) + if err != nil { + return []bool{}, fmt.Errorf("unable to cast %#v of type %T to []bool", i, i) + } + a[j] = val + } + return a, nil + default: + return []bool{}, fmt.Errorf("unable to cast %#v of type %T to []bool", i, i) + } +} + +// ToStringSliceE casts an interface to a []string type. +func ToStringSliceE(i interface{}) ([]string, error) { + var a []string + + switch v := i.(type) { + case []interface{}: + for _, u := range v { + a = append(a, ToString(u)) + } + return a, nil + case []string: + return v, nil + case []int8: + for _, u := range v { + a = append(a, ToString(u)) + } + return a, nil + case []int: + for _, u := range v { + a = append(a, ToString(u)) + } + return a, nil + case []int32: + for _, u := range v { + a = append(a, ToString(u)) + } + return a, nil + case []int64: + for _, u := range v { + a = append(a, ToString(u)) + } + return a, nil + case []float32: + for _, u := range v { + a = append(a, ToString(u)) + } + return a, nil + case []float64: + for _, u := range v { + a = append(a, ToString(u)) + } + return a, nil + case string: + return strings.Fields(v), nil + case []error: + for _, err := range i.([]error) { + a = append(a, err.Error()) + } + return a, nil + case interface{}: + str, err := ToStringE(v) + if err != nil { + return a, fmt.Errorf("unable to cast %#v of type %T to []string", i, i) + } + return []string{str}, nil + default: + return a, fmt.Errorf("unable to cast %#v of type %T to []string", i, i) + } +} + +// ToIntSliceE casts an interface to a []int type. +func ToIntSliceE(i interface{}) ([]int, error) { + if i == nil { + return []int{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i) + } + + switch v := i.(type) { + case []int: + return v, nil + } + + kind := reflect.TypeOf(i).Kind() + switch kind { + case reflect.Slice, reflect.Array: + s := reflect.ValueOf(i) + a := make([]int, s.Len()) + for j := 0; j < s.Len(); j++ { + val, err := ToIntE(s.Index(j).Interface()) + if err != nil { + return []int{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i) + } + a[j] = val + } + return a, nil + default: + return []int{}, fmt.Errorf("unable to cast %#v of type %T to []int", i, i) + } +} + +// ToDurationSliceE casts an interface to a []time.Duration type. +func ToDurationSliceE(i interface{}) ([]time.Duration, error) { + if i == nil { + return []time.Duration{}, fmt.Errorf("unable to cast %#v of type %T to []time.Duration", i, i) + } + + switch v := i.(type) { + case []time.Duration: + return v, nil + } + + kind := reflect.TypeOf(i).Kind() + switch kind { + case reflect.Slice, reflect.Array: + s := reflect.ValueOf(i) + a := make([]time.Duration, s.Len()) + for j := 0; j < s.Len(); j++ { + val, err := ToDurationE(s.Index(j).Interface()) + if err != nil { + return []time.Duration{}, fmt.Errorf("unable to cast %#v of type %T to []time.Duration", i, i) + } + a[j] = val + } + return a, nil + default: + return []time.Duration{}, fmt.Errorf("unable to cast %#v of type %T to []time.Duration", i, i) + } +} + +// StringToDate attempts to parse a string into a time.Time type using a +// predefined list of formats. If no suitable format is found, an error is +// returned. +func StringToDate(s string) (time.Time, error) { + return parseDateWith(s, time.UTC, timeFormats) +} + +// StringToDateInDefaultLocation casts an empty interface to a time.Time, +// interpreting inputs without a timezone to be in the given location, +// or the local timezone if nil. +func StringToDateInDefaultLocation(s string, location *time.Location) (time.Time, error) { + return parseDateWith(s, location, timeFormats) +} + +type timeFormatType int + +const ( + timeFormatNoTimezone timeFormatType = iota + timeFormatNamedTimezone + timeFormatNumericTimezone + timeFormatNumericAndNamedTimezone + timeFormatTimeOnly +) + +type timeFormat struct { + format string + typ timeFormatType +} + +func (f timeFormat) hasTimezone() bool { + // We don't include the formats with only named timezones, see + // https://github.com/golang/go/issues/19694#issuecomment-289103522 + return f.typ >= timeFormatNumericTimezone && f.typ <= timeFormatNumericAndNamedTimezone +} + +var timeFormats = []timeFormat{ + // Keep common formats at the top. + {"2006-01-02", timeFormatNoTimezone}, + {time.RFC3339, timeFormatNumericTimezone}, + {"2006-01-02T15:04:05", timeFormatNoTimezone}, // iso8601 without timezone + {time.RFC1123Z, timeFormatNumericTimezone}, + {time.RFC1123, timeFormatNamedTimezone}, + {time.RFC822Z, timeFormatNumericTimezone}, + {time.RFC822, timeFormatNamedTimezone}, + {time.RFC850, timeFormatNamedTimezone}, + {"2006-01-02 15:04:05.999999999 -0700 MST", timeFormatNumericAndNamedTimezone}, // Time.String() + {"2006-01-02T15:04:05-0700", timeFormatNumericTimezone}, // RFC3339 without timezone hh:mm colon + {"2006-01-02 15:04:05Z0700", timeFormatNumericTimezone}, // RFC3339 without T or timezone hh:mm colon + {"2006-01-02 15:04:05", timeFormatNoTimezone}, + {time.ANSIC, timeFormatNoTimezone}, + {time.UnixDate, timeFormatNamedTimezone}, + {time.RubyDate, timeFormatNumericTimezone}, + {"2006-01-02 15:04:05Z07:00", timeFormatNumericTimezone}, + {"02 Jan 2006", timeFormatNoTimezone}, + {"2006-01-02 15:04:05 -07:00", timeFormatNumericTimezone}, + {"2006-01-02 15:04:05 -0700", timeFormatNumericTimezone}, + {time.Kitchen, timeFormatTimeOnly}, + {time.Stamp, timeFormatTimeOnly}, + {time.StampMilli, timeFormatTimeOnly}, + {time.StampMicro, timeFormatTimeOnly}, + {time.StampNano, timeFormatTimeOnly}, +} + +func parseDateWith(s string, location *time.Location, formats []timeFormat) (d time.Time, e error) { + for _, format := range formats { + if d, e = time.Parse(format.format, s); e == nil { + + // Some time formats have a zone name, but no offset, so it gets + // put in that zone name (not the default one passed in to us), but + // without that zone's offset. So set the location manually. + if format.typ <= timeFormatNamedTimezone { + if location == nil { + location = time.Local + } + year, month, day := d.Date() + hour, min, sec := d.Clock() + d = time.Date(year, month, day, hour, min, sec, d.Nanosecond(), location) + } + + return + } + } + return d, fmt.Errorf("unable to parse date: %s", s) +} + +// jsonStringToObject attempts to unmarshall a string as JSON into +// the object passed as pointer. +func jsonStringToObject(s string, v interface{}) error { + data := []byte(s) + return json.Unmarshal(data, v) +} + +// toInt returns the int value of v if v or v's underlying type +// is an int. +// Note that this will return false for int64 etc. types. +func toInt(v interface{}) (int, bool) { + switch v := v.(type) { + case int: + return v, true + case time.Weekday: + return int(v), true + case time.Month: + return int(v), true + default: + return 0, false + } +} + +func trimZeroDecimal(s string) string { + var foundZero bool + for i := len(s); i > 0; i-- { + switch s[i-1] { + case '.': + if foundZero { + return s[:i-1] + } + case '0': + foundZero = true + default: + return s + } + } + return s +} diff --git a/vendor/github.com/spf13/cast/timeformattype_string.go b/vendor/github.com/spf13/cast/timeformattype_string.go new file mode 100644 index 0000000000..1524fc82ce --- /dev/null +++ b/vendor/github.com/spf13/cast/timeformattype_string.go @@ -0,0 +1,27 @@ +// Code generated by "stringer -type timeFormatType"; DO NOT EDIT. + +package cast + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[timeFormatNoTimezone-0] + _ = x[timeFormatNamedTimezone-1] + _ = x[timeFormatNumericTimezone-2] + _ = x[timeFormatNumericAndNamedTimezone-3] + _ = x[timeFormatTimeOnly-4] +} + +const _timeFormatType_name = "timeFormatNoTimezonetimeFormatNamedTimezonetimeFormatNumericTimezonetimeFormatNumericAndNamedTimezonetimeFormatTimeOnly" + +var _timeFormatType_index = [...]uint8{0, 20, 43, 68, 101, 119} + +func (i timeFormatType) String() string { + if i < 0 || i >= timeFormatType(len(_timeFormatType_index)-1) { + return "timeFormatType(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _timeFormatType_name[_timeFormatType_index[i]:_timeFormatType_index[i+1]] +} diff --git a/vendor/github.com/wk8/go-ordered-map/v2/.gitignore b/vendor/github.com/wk8/go-ordered-map/v2/.gitignore new file mode 100644 index 0000000000..57872d0f1e --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/.gitignore @@ -0,0 +1 @@ +/vendor/ diff --git a/vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml b/vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml new file mode 100644 index 0000000000..2417df10d9 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/.golangci.yml @@ -0,0 +1,80 @@ +run: + tests: false + +linters: + disable-all: true + enable: + - asciicheck + - bidichk + - bodyclose + - containedctx + - contextcheck + - decorder + - depguard + - dogsled + - dupl + - durationcheck + - errcheck + - errchkjson + # FIXME: commented out as it crashes with 1.18 for now + # - errname + - errorlint + - exportloopref + - forbidigo + - funlen + - gci + - gochecknoglobals + - gochecknoinits + - gocognit + - goconst + - gocritic + - gocyclo + - godox + - gofmt + - gofumpt + - goheader + - goimports + - gomnd + - gomoddirectives + - gomodguard + - goprintffuncname + - gosec + - gosimple + - govet + - grouper + - ifshort + - importas + - ineffassign + - lll + - maintidx + - makezero + - misspell + - nakedret + - nilerr + - nilnil + - noctx + - nolintlint + - paralleltest + - prealloc + - predeclared + - promlinter + # FIXME: doesn't support 1.18 yet + # - revive + - rowserrcheck + - sqlclosecheck + - staticcheck + - structcheck + - stylecheck + - tagliatelle + - tenv + - testpackage + - thelper + - tparallel + - typecheck + - unconvert + - unparam + - unused + - varcheck + - varnamelen + - wastedassign + - whitespace diff --git a/vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md b/vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md new file mode 100644 index 0000000000..f27126f84f --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/CHANGELOG.md @@ -0,0 +1,38 @@ +# Changelog + +[comment]: # (Changes since last release go here) + +## 2.1.8 - Jun 27th 2023 + +* Added support for YAML serialization/deserialization + +## 2.1.7 - Apr 13th 2023 + +* Renamed test_utils.go to utils_test.go + +## 2.1.6 - Feb 15th 2023 + +* Added `GetAndMoveToBack()` and `GetAndMoveToFront()` methods + +## 2.1.5 - Dec 13th 2022 + +* Added `Value()` method + +## 2.1.4 - Dec 12th 2022 + +* Fixed a bug with UTF-8 special characters in JSON keys + +## 2.1.3 - Dec 11th 2022 + +* Added support for JSON marshalling/unmarshalling of wrapper of primitive types + +## 2.1.2 - Dec 10th 2022 +* Allowing to pass options to `New`, to give a capacity hint, or initial data +* Allowing to deserialize nested ordered maps from JSON without having to explicitly instantiate them +* Added the `AddPairs` method + +## 2.1.1 - Dec 9th 2022 +* Fixing a bug with JSON marshalling + +## 2.1.0 - Dec 7th 2022 +* Added support for JSON serialization/deserialization diff --git a/vendor/github.com/wk8/go-ordered-map/v2/LICENSE b/vendor/github.com/wk8/go-ordered-map/v2/LICENSE new file mode 100644 index 0000000000..8dada3edaf --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/vendor/github.com/wk8/go-ordered-map/v2/Makefile b/vendor/github.com/wk8/go-ordered-map/v2/Makefile new file mode 100644 index 0000000000..6e0e18a1b9 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/Makefile @@ -0,0 +1,32 @@ +.DEFAULT_GOAL := all + +.PHONY: all +all: test_with_fuzz lint + +# the TEST_FLAGS env var can be set to eg run only specific tests +TEST_COMMAND = go test -v -count=1 -race -cover $(TEST_FLAGS) + +.PHONY: test +test: + $(TEST_COMMAND) + +.PHONY: bench +bench: + go test -bench=. + +FUZZ_TIME ?= 10s + +# see https://github.com/golang/go/issues/46312 +# and https://stackoverflow.com/a/72673487/4867444 +# if we end up having more fuzz tests +.PHONY: test_with_fuzz +test_with_fuzz: + $(TEST_COMMAND) -fuzz=FuzzRoundTripJSON -fuzztime=$(FUZZ_TIME) + $(TEST_COMMAND) -fuzz=FuzzRoundTripYAML -fuzztime=$(FUZZ_TIME) + +.PHONY: fuzz +fuzz: test_with_fuzz + +.PHONY: lint +lint: + golangci-lint run diff --git a/vendor/github.com/wk8/go-ordered-map/v2/README.md b/vendor/github.com/wk8/go-ordered-map/v2/README.md new file mode 100644 index 0000000000..b028944437 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/README.md @@ -0,0 +1,154 @@ +[![Go Reference](https://pkg.go.dev/badge/github.com/wk8/go-ordered-map/v2.svg)](https://pkg.go.dev/github.com/wk8/go-ordered-map/v2) +[![Build Status](https://circleci.com/gh/wk8/go-ordered-map.svg?style=svg)](https://app.circleci.com/pipelines/github/wk8/go-ordered-map) + +# Golang Ordered Maps + +Same as regular maps, but also remembers the order in which keys were inserted, akin to [Python's `collections.OrderedDict`s](https://docs.python.org/3.7/library/collections.html#ordereddict-objects). + +It offers the following features: +* optimal runtime performance (all operations are constant time) +* optimal memory usage (only one copy of values, no unnecessary memory allocation) +* allows iterating from newest or oldest keys indifferently, without memory copy, allowing to `break` the iteration, and in time linear to the number of keys iterated over rather than the total length of the ordered map +* supports any generic types for both keys and values. If you're running go < 1.18, you can use [version 1](https://github.com/wk8/go-ordered-map/tree/v1) that takes and returns generic `interface{}`s instead of using generics +* idiomatic API, akin to that of [`container/list`](https://golang.org/pkg/container/list) +* support for JSON and YAML marshalling + +## Documentation + +[The full documentation is available on pkg.go.dev](https://pkg.go.dev/github.com/wk8/go-ordered-map/v2). + +## Installation +```bash +go get -u github.com/wk8/go-ordered-map/v2 +``` + +Or use your favorite golang vendoring tool! + +## Supported go versions + +Go >= 1.18 is required to use version >= 2 of this library, as it uses generics. + +If you're running go < 1.18, you can use [version 1](https://github.com/wk8/go-ordered-map/tree/v1) instead. + +## Example / usage + +```go +package main + +import ( + "fmt" + + "github.com/wk8/go-ordered-map/v2" +) + +func main() { + om := orderedmap.New[string, string]() + + om.Set("foo", "bar") + om.Set("bar", "baz") + om.Set("coucou", "toi") + + fmt.Println(om.Get("foo")) // => "bar", true + fmt.Println(om.Get("i dont exist")) // => "", false + + // iterating pairs from oldest to newest: + for pair := om.Oldest(); pair != nil; pair = pair.Next() { + fmt.Printf("%s => %s\n", pair.Key, pair.Value) + } // prints: + // foo => bar + // bar => baz + // coucou => toi + + // iterating over the 2 newest pairs: + i := 0 + for pair := om.Newest(); pair != nil; pair = pair.Prev() { + fmt.Printf("%s => %s\n", pair.Key, pair.Value) + i++ + if i >= 2 { + break + } + } // prints: + // coucou => toi + // bar => baz +} +``` + +An `OrderedMap`'s keys must implement `comparable`, and its values can be anything, for example: + +```go +type myStruct struct { + payload string +} + +func main() { + om := orderedmap.New[int, *myStruct]() + + om.Set(12, &myStruct{"foo"}) + om.Set(1, &myStruct{"bar"}) + + value, present := om.Get(12) + if !present { + panic("should be there!") + } + fmt.Println(value.payload) // => foo + + for pair := om.Oldest(); pair != nil; pair = pair.Next() { + fmt.Printf("%d => %s\n", pair.Key, pair.Value.payload) + } // prints: + // 12 => foo + // 1 => bar +} +``` + +Also worth noting that you can provision ordered maps with a capacity hint, as you would do by passing an optional hint to `make(map[K]V, capacity`): +```go +om := orderedmap.New[int, *myStruct](28) +``` + +You can also pass in some initial data to store in the map: +```go +om := orderedmap.New[int, string](orderedmap.WithInitialData[int, string]( + orderedmap.Pair[int, string]{ + Key: 12, + Value: "foo", + }, + orderedmap.Pair[int, string]{ + Key: 28, + Value: "bar", + }, +)) +``` + +`OrderedMap`s also support JSON serialization/deserialization, and preserves order: + +```go +// serialization +data, err := json.Marshal(om) +... + +// deserialization +om := orderedmap.New[string, string]() // or orderedmap.New[int, any](), or any type you expect +err := json.Unmarshal(data, &om) +... +``` + +Similarly, it also supports YAML serialization/deserialization using the yaml.v3 package, which also preserves order: + +```go +// serialization +data, err := yaml.Marshal(om) +... + +// deserialization +om := orderedmap.New[string, string]() // or orderedmap.New[int, any](), or any type you expect +err := yaml.Unmarshal(data, &om) +... +``` + +## Alternatives + +There are several other ordered map golang implementations out there, but I believe that at the time of writing none of them offer the same functionality as this library; more specifically: +* [iancoleman/orderedmap](https://github.com/iancoleman/orderedmap) only accepts `string` keys, its `Delete` operations are linear +* [cevaris/ordered_map](https://github.com/cevaris/ordered_map) uses a channel for iterations, and leaks goroutines if the iteration is interrupted before fully traversing the map +* [mantyr/iterator](https://github.com/mantyr/iterator) also uses a channel for iterations, and its `Delete` operations are linear +* [samdolan/go-ordered-map](https://github.com/samdolan/go-ordered-map) adds unnecessary locking (users should add their own locking instead if they need it), its `Delete` and `Get` operations are linear, iterations trigger a linear memory allocation diff --git a/vendor/github.com/wk8/go-ordered-map/v2/json.go b/vendor/github.com/wk8/go-ordered-map/v2/json.go new file mode 100644 index 0000000000..a545b536b3 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/json.go @@ -0,0 +1,182 @@ +package orderedmap + +import ( + "bytes" + "encoding" + "encoding/json" + "fmt" + "reflect" + "unicode/utf8" + + "github.com/buger/jsonparser" + "github.com/mailru/easyjson/jwriter" +) + +var ( + _ json.Marshaler = &OrderedMap[int, any]{} + _ json.Unmarshaler = &OrderedMap[int, any]{} +) + +// MarshalJSON implements the json.Marshaler interface. +func (om *OrderedMap[K, V]) MarshalJSON() ([]byte, error) { //nolint:funlen + if om == nil || om.list == nil { + return []byte("null"), nil + } + + writer := jwriter.Writer{} + writer.RawByte('{') + + for pair, firstIteration := om.Oldest(), true; pair != nil; pair = pair.Next() { + if firstIteration { + firstIteration = false + } else { + writer.RawByte(',') + } + + switch key := any(pair.Key).(type) { + case string: + writer.String(key) + case encoding.TextMarshaler: + writer.RawByte('"') + writer.Raw(key.MarshalText()) + writer.RawByte('"') + case int: + writer.IntStr(key) + case int8: + writer.Int8Str(key) + case int16: + writer.Int16Str(key) + case int32: + writer.Int32Str(key) + case int64: + writer.Int64Str(key) + case uint: + writer.UintStr(key) + case uint8: + writer.Uint8Str(key) + case uint16: + writer.Uint16Str(key) + case uint32: + writer.Uint32Str(key) + case uint64: + writer.Uint64Str(key) + default: + + // this switch takes care of wrapper types around primitive types, such as + // type myType string + switch keyValue := reflect.ValueOf(key); keyValue.Type().Kind() { + case reflect.String: + writer.String(keyValue.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + writer.Int64Str(keyValue.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + writer.Uint64Str(keyValue.Uint()) + default: + return nil, fmt.Errorf("unsupported key type: %T", key) + } + } + + writer.RawByte(':') + // the error is checked at the end of the function + writer.Raw(json.Marshal(pair.Value)) //nolint:errchkjson + } + + writer.RawByte('}') + + return dumpWriter(&writer) +} + +func dumpWriter(writer *jwriter.Writer) ([]byte, error) { + if writer.Error != nil { + return nil, writer.Error + } + + var buf bytes.Buffer + buf.Grow(writer.Size()) + if _, err := writer.DumpTo(&buf); err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (om *OrderedMap[K, V]) UnmarshalJSON(data []byte) error { + if om.list == nil { + om.initialize(0) + } + + return jsonparser.ObjectEach( + data, + func(keyData []byte, valueData []byte, dataType jsonparser.ValueType, offset int) error { + if dataType == jsonparser.String { + // jsonparser removes the enclosing quotes; we need to restore them to make a valid JSON + valueData = data[offset-len(valueData)-2 : offset] + } + + var key K + var value V + + switch typedKey := any(&key).(type) { + case *string: + s, err := decodeUTF8(keyData) + if err != nil { + return err + } + *typedKey = s + case encoding.TextUnmarshaler: + if err := typedKey.UnmarshalText(keyData); err != nil { + return err + } + case *int, *int8, *int16, *int32, *int64, *uint, *uint8, *uint16, *uint32, *uint64: + if err := json.Unmarshal(keyData, typedKey); err != nil { + return err + } + default: + // this switch takes care of wrapper types around primitive types, such as + // type myType string + switch reflect.TypeOf(key).Kind() { + case reflect.String: + s, err := decodeUTF8(keyData) + if err != nil { + return err + } + + convertedKeyData := reflect.ValueOf(s).Convert(reflect.TypeOf(key)) + reflect.ValueOf(&key).Elem().Set(convertedKeyData) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if err := json.Unmarshal(keyData, &key); err != nil { + return err + } + default: + return fmt.Errorf("unsupported key type: %T", key) + } + } + + if err := json.Unmarshal(valueData, &value); err != nil { + return err + } + + om.Set(key, value) + return nil + }) +} + +func decodeUTF8(input []byte) (string, error) { + remaining, offset := input, 0 + runes := make([]rune, 0, len(remaining)) + + for len(remaining) > 0 { + r, size := utf8.DecodeRune(remaining) + if r == utf8.RuneError && size <= 1 { + return "", fmt.Errorf("not a valid UTF-8 string (at position %d): %s", offset, string(input)) + } + + runes = append(runes, r) + remaining = remaining[size:] + offset += size + } + + return string(runes), nil +} diff --git a/vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go b/vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go new file mode 100644 index 0000000000..0647141919 --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/orderedmap.go @@ -0,0 +1,296 @@ +// Package orderedmap implements an ordered map, i.e. a map that also keeps track of +// the order in which keys were inserted. +// +// All operations are constant-time. +// +// Github repo: https://github.com/wk8/go-ordered-map +// +package orderedmap + +import ( + "fmt" + + list "github.com/bahlo/generic-list-go" +) + +type Pair[K comparable, V any] struct { + Key K + Value V + + element *list.Element[*Pair[K, V]] +} + +type OrderedMap[K comparable, V any] struct { + pairs map[K]*Pair[K, V] + list *list.List[*Pair[K, V]] +} + +type initConfig[K comparable, V any] struct { + capacity int + initialData []Pair[K, V] +} + +type InitOption[K comparable, V any] func(config *initConfig[K, V]) + +// WithCapacity allows giving a capacity hint for the map, akin to the standard make(map[K]V, capacity). +func WithCapacity[K comparable, V any](capacity int) InitOption[K, V] { + return func(c *initConfig[K, V]) { + c.capacity = capacity + } +} + +// WithInitialData allows passing in initial data for the map. +func WithInitialData[K comparable, V any](initialData ...Pair[K, V]) InitOption[K, V] { + return func(c *initConfig[K, V]) { + c.initialData = initialData + if c.capacity < len(initialData) { + c.capacity = len(initialData) + } + } +} + +// New creates a new OrderedMap. +// options can either be one or several InitOption[K, V], or a single integer, +// which is then interpreted as a capacity hint, à la make(map[K]V, capacity). +func New[K comparable, V any](options ...any) *OrderedMap[K, V] { //nolint:varnamelen + orderedMap := &OrderedMap[K, V]{} + + var config initConfig[K, V] + for _, untypedOption := range options { + switch option := untypedOption.(type) { + case int: + if len(options) != 1 { + invalidOption() + } + config.capacity = option + + case InitOption[K, V]: + option(&config) + + default: + invalidOption() + } + } + + orderedMap.initialize(config.capacity) + orderedMap.AddPairs(config.initialData...) + + return orderedMap +} + +const invalidOptionMessage = `when using orderedmap.New[K,V]() with options, either provide one or several InitOption[K, V]; or a single integer which is then interpreted as a capacity hint, à la make(map[K]V, capacity).` //nolint:lll + +func invalidOption() { panic(invalidOptionMessage) } + +func (om *OrderedMap[K, V]) initialize(capacity int) { + om.pairs = make(map[K]*Pair[K, V], capacity) + om.list = list.New[*Pair[K, V]]() +} + +// Get looks for the given key, and returns the value associated with it, +// or V's nil value if not found. The boolean it returns says whether the key is present in the map. +func (om *OrderedMap[K, V]) Get(key K) (val V, present bool) { + if pair, present := om.pairs[key]; present { + return pair.Value, true + } + + return +} + +// Load is an alias for Get, mostly to present an API similar to `sync.Map`'s. +func (om *OrderedMap[K, V]) Load(key K) (V, bool) { + return om.Get(key) +} + +// Value returns the value associated with the given key or the zero value. +func (om *OrderedMap[K, V]) Value(key K) (val V) { + if pair, present := om.pairs[key]; present { + val = pair.Value + } + return +} + +// GetPair looks for the given key, and returns the pair associated with it, +// or nil if not found. The Pair struct can then be used to iterate over the ordered map +// from that point, either forward or backward. +func (om *OrderedMap[K, V]) GetPair(key K) *Pair[K, V] { + return om.pairs[key] +} + +// Set sets the key-value pair, and returns what `Get` would have returned +// on that key prior to the call to `Set`. +func (om *OrderedMap[K, V]) Set(key K, value V) (val V, present bool) { + if pair, present := om.pairs[key]; present { + oldValue := pair.Value + pair.Value = value + return oldValue, true + } + + pair := &Pair[K, V]{ + Key: key, + Value: value, + } + pair.element = om.list.PushBack(pair) + om.pairs[key] = pair + + return +} + +// AddPairs allows setting multiple pairs at a time. It's equivalent to calling +// Set on each pair sequentially. +func (om *OrderedMap[K, V]) AddPairs(pairs ...Pair[K, V]) { + for _, pair := range pairs { + om.Set(pair.Key, pair.Value) + } +} + +// Store is an alias for Set, mostly to present an API similar to `sync.Map`'s. +func (om *OrderedMap[K, V]) Store(key K, value V) (V, bool) { + return om.Set(key, value) +} + +// Delete removes the key-value pair, and returns what `Get` would have returned +// on that key prior to the call to `Delete`. +func (om *OrderedMap[K, V]) Delete(key K) (val V, present bool) { + if pair, present := om.pairs[key]; present { + om.list.Remove(pair.element) + delete(om.pairs, key) + return pair.Value, true + } + return +} + +// Len returns the length of the ordered map. +func (om *OrderedMap[K, V]) Len() int { + if om == nil || om.pairs == nil { + return 0 + } + return len(om.pairs) +} + +// Oldest returns a pointer to the oldest pair. It's meant to be used to iterate on the ordered map's +// pairs from the oldest to the newest, e.g.: +// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) } +func (om *OrderedMap[K, V]) Oldest() *Pair[K, V] { + if om == nil || om.list == nil { + return nil + } + return listElementToPair(om.list.Front()) +} + +// Newest returns a pointer to the newest pair. It's meant to be used to iterate on the ordered map's +// pairs from the newest to the oldest, e.g.: +// for pair := orderedMap.Oldest(); pair != nil; pair = pair.Next() { fmt.Printf("%v => %v\n", pair.Key, pair.Value) } +func (om *OrderedMap[K, V]) Newest() *Pair[K, V] { + if om == nil || om.list == nil { + return nil + } + return listElementToPair(om.list.Back()) +} + +// Next returns a pointer to the next pair. +func (p *Pair[K, V]) Next() *Pair[K, V] { + return listElementToPair(p.element.Next()) +} + +// Prev returns a pointer to the previous pair. +func (p *Pair[K, V]) Prev() *Pair[K, V] { + return listElementToPair(p.element.Prev()) +} + +func listElementToPair[K comparable, V any](element *list.Element[*Pair[K, V]]) *Pair[K, V] { + if element == nil { + return nil + } + return element.Value +} + +// KeyNotFoundError may be returned by functions in this package when they're called with keys that are not present +// in the map. +type KeyNotFoundError[K comparable] struct { + MissingKey K +} + +func (e *KeyNotFoundError[K]) Error() string { + return fmt.Sprintf("missing key: %v", e.MissingKey) +} + +// MoveAfter moves the value associated with key to its new position after the one associated with markKey. +// Returns an error iff key or markKey are not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveAfter(key, markKey K) error { + elements, err := om.getElements(key, markKey) + if err != nil { + return err + } + om.list.MoveAfter(elements[0], elements[1]) + return nil +} + +// MoveBefore moves the value associated with key to its new position before the one associated with markKey. +// Returns an error iff key or markKey are not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveBefore(key, markKey K) error { + elements, err := om.getElements(key, markKey) + if err != nil { + return err + } + om.list.MoveBefore(elements[0], elements[1]) + return nil +} + +func (om *OrderedMap[K, V]) getElements(keys ...K) ([]*list.Element[*Pair[K, V]], error) { + elements := make([]*list.Element[*Pair[K, V]], len(keys)) + for i, k := range keys { + pair, present := om.pairs[k] + if !present { + return nil, &KeyNotFoundError[K]{k} + } + elements[i] = pair.element + } + return elements, nil +} + +// MoveToBack moves the value associated with key to the back of the ordered map, +// i.e. makes it the newest pair in the map. +// Returns an error iff key is not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveToBack(key K) error { + _, err := om.GetAndMoveToBack(key) + return err +} + +// MoveToFront moves the value associated with key to the front of the ordered map, +// i.e. makes it the oldest pair in the map. +// Returns an error iff key is not present in the map. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) MoveToFront(key K) error { + _, err := om.GetAndMoveToFront(key) + return err +} + +// GetAndMoveToBack combines Get and MoveToBack in the same call. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) GetAndMoveToBack(key K) (val V, err error) { + if pair, present := om.pairs[key]; present { + val = pair.Value + om.list.MoveToBack(pair.element) + } else { + err = &KeyNotFoundError[K]{key} + } + + return +} + +// GetAndMoveToFront combines Get and MoveToFront in the same call. If an error is returned, +// it will be a KeyNotFoundError. +func (om *OrderedMap[K, V]) GetAndMoveToFront(key K) (val V, err error) { + if pair, present := om.pairs[key]; present { + val = pair.Value + om.list.MoveToFront(pair.element) + } else { + err = &KeyNotFoundError[K]{key} + } + + return +} diff --git a/vendor/github.com/wk8/go-ordered-map/v2/yaml.go b/vendor/github.com/wk8/go-ordered-map/v2/yaml.go new file mode 100644 index 0000000000..602247128f --- /dev/null +++ b/vendor/github.com/wk8/go-ordered-map/v2/yaml.go @@ -0,0 +1,71 @@ +package orderedmap + +import ( + "fmt" + + "gopkg.in/yaml.v3" +) + +var ( + _ yaml.Marshaler = &OrderedMap[int, any]{} + _ yaml.Unmarshaler = &OrderedMap[int, any]{} +) + +// MarshalYAML implements the yaml.Marshaler interface. +func (om *OrderedMap[K, V]) MarshalYAML() (interface{}, error) { + if om == nil { + return []byte("null"), nil + } + + node := yaml.Node{ + Kind: yaml.MappingNode, + } + + for pair := om.Oldest(); pair != nil; pair = pair.Next() { + key, value := pair.Key, pair.Value + + keyNode := &yaml.Node{} + + // serialize key to yaml, then deserialize it back into the node + // this is a hack to get the correct tag for the key + if err := keyNode.Encode(key); err != nil { + return nil, err + } + + valueNode := &yaml.Node{} + if err := valueNode.Encode(value); err != nil { + return nil, err + } + + node.Content = append(node.Content, keyNode, valueNode) + } + + return &node, nil +} + +// UnmarshalYAML implements the yaml.Unmarshaler interface. +func (om *OrderedMap[K, V]) UnmarshalYAML(value *yaml.Node) error { + if value.Kind != yaml.MappingNode { + return fmt.Errorf("pipeline must contain YAML mapping, has %v", value.Kind) + } + + if om.list == nil { + om.initialize(0) + } + + for index := 0; index < len(value.Content); index += 2 { + var key K + var val V + + if err := value.Content[index].Decode(&key); err != nil { + return err + } + if err := value.Content[index+1].Decode(&val); err != nil { + return err + } + + om.Set(key, val) + } + + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/LICENSE b/vendor/github.com/yosida95/uritemplate/v3/LICENSE new file mode 100644 index 0000000000..79e8f87572 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/LICENSE @@ -0,0 +1,25 @@ +Copyright (C) 2016, Kohei YOSHIDA . All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + * Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/yosida95/uritemplate/v3/README.rst b/vendor/github.com/yosida95/uritemplate/v3/README.rst new file mode 100644 index 0000000000..6815d0a465 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/README.rst @@ -0,0 +1,46 @@ +uritemplate +=========== + +`uritemplate`_ is a Go implementation of `URI Template`_ [RFC6570] with +full functionality of URI Template Level 4. + +uritemplate can also generate a regexp that matches expansion of the +URI Template from a URI Template. + +Getting Started +--------------- + +Installation +~~~~~~~~~~~~ + +.. code-block:: sh + + $ go get -u github.com/yosida95/uritemplate/v3 + +Documentation +~~~~~~~~~~~~~ + +The documentation is available on GoDoc_. + +Examples +-------- + +See `examples on GoDoc`_. + +License +------- + +`uritemplate`_ is distributed under the BSD 3-Clause license. +PLEASE READ ./LICENSE carefully and follow its clauses to use this software. + +Author +------ + +yosida95_ + + +.. _`URI Template`: https://tools.ietf.org/html/rfc6570 +.. _Godoc: https://godoc.org/github.com/yosida95/uritemplate +.. _`examples on GoDoc`: https://godoc.org/github.com/yosida95/uritemplate#pkg-examples +.. _yosida95: https://yosida95.com/ +.. _uritemplate: https://github.com/yosida95/uritemplate diff --git a/vendor/github.com/yosida95/uritemplate/v3/compile.go b/vendor/github.com/yosida95/uritemplate/v3/compile.go new file mode 100644 index 0000000000..bd774d15d0 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/compile.go @@ -0,0 +1,224 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode/utf8" +) + +type compiler struct { + prog *prog +} + +func (c *compiler) init() { + c.prog = &prog{} +} + +func (c *compiler) op(opcode progOpcode) uint32 { + i := len(c.prog.op) + c.prog.op = append(c.prog.op, progOp{code: opcode}) + return uint32(i) +} + +func (c *compiler) opWithRune(opcode progOpcode, r rune) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).r = r + return addr +} + +func (c *compiler) opWithRuneClass(opcode progOpcode, rc runeClass) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).rc = rc + return addr +} + +func (c *compiler) opWithAddr(opcode progOpcode, absaddr uint32) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).i = absaddr + return addr +} + +func (c *compiler) opWithAddrDelta(opcode progOpcode, delta uint32) uint32 { + return c.opWithAddr(opcode, uint32(len(c.prog.op))+delta) +} + +func (c *compiler) opWithName(opcode progOpcode, name string) uint32 { + addr := c.op(opcode) + (&c.prog.op[addr]).name = name + return addr +} + +func (c *compiler) compileString(str string) { + for i := 0; i < len(str); { + // NOTE(yosida95): It is confirmed at parse time that literals + // consist of only valid-UTF8 runes. + r, size := utf8.DecodeRuneInString(str[i:]) + c.opWithRune(opRune, r) + i += size + } +} + +func (c *compiler) compileRuneClass(rc runeClass, maxlen int) { + for i := 0; i < maxlen; i++ { + if i > 0 { + c.opWithAddrDelta(opSplit, 7) + } + c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + } +} + +func (c *compiler) compileRuneClassInfinite(rc runeClass) { + start := c.opWithAddrDelta(opSplit, 3) // raw rune or pct-encoded + c.opWithRuneClass(opRuneClass, rc) // raw rune + c.opWithAddrDelta(opJmp, 4) // + c.opWithRune(opRune, '%') // pct-encoded + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithRuneClass(opRuneClass, runeClassPctE) // + c.opWithAddrDelta(opSplit, 2) // loop + c.opWithAddr(opJmp, start) // +} + +func (c *compiler) compileVarspecValue(spec varspec, expr *expression) { + var specname string + if spec.maxlen > 0 { + specname = fmt.Sprintf("%s:%d", spec.name, spec.maxlen) + } else { + specname = spec.name + } + + c.prog.numCap++ + + c.opWithName(opCapStart, specname) + + split := c.op(opSplit) + if spec.maxlen > 0 { + c.compileRuneClass(expr.allow, spec.maxlen) + } else { + c.compileRuneClassInfinite(expr.allow) + } + + capEnd := c.opWithName(opCapEnd, specname) + c.prog.op[split].i = capEnd +} + +func (c *compiler) compileVarspec(spec varspec, expr *expression) { + switch { + case expr.named && spec.explode: + split1 := c.op(opSplit) + noop := c.op(opNoop) + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + c.compileVarspecValue(spec, expr) + + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.opWithAddr(opJmp, noop) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + c.opWithAddr(opJmp, split3) + + c.prog.op[split1].i = uint32(len(c.prog.op)) + c.prog.op[split3].i = uint32(len(c.prog.op)) + + case expr.named && !spec.explode: + c.compileString(spec.name) + + split2 := c.op(opSplit) + c.opWithRune(opRune, '=') + + split3 := c.op(opSplit) + + split4 := c.op(opSplit) + c.compileVarspecValue(spec, expr) + + split5 := c.op(opSplit) + c.prog.op[split4].i = split5 + c.compileString(",") + c.opWithAddr(opJmp, split4) + + c.prog.op[split3].i = uint32(len(c.prog.op)) + c.compileString(",") + jmp1 := c.op(opJmp) + + c.prog.op[split2].i = uint32(len(c.prog.op)) + c.compileString(expr.ifemp) + + c.prog.op[split5].i = uint32(len(c.prog.op)) + c.prog.op[jmp1].i = uint32(len(c.prog.op)) + + case !expr.named: + start := uint32(len(c.prog.op)) + c.compileVarspecValue(spec, expr) + + split1 := c.op(opSplit) + jmp := c.op(opJmp) + + c.prog.op[split1].i = uint32(len(c.prog.op)) + if spec.explode { + c.compileString(expr.sep) + } else { + c.opWithRune(opRune, ',') + } + c.opWithAddr(opJmp, start) + + c.prog.op[jmp].i = uint32(len(c.prog.op)) + } +} + +func (c *compiler) compileExpression(expr *expression) { + if len(expr.vars) < 1 { + return + } + + split1 := c.op(opSplit) + c.compileString(expr.first) + + for i, size := 0, len(expr.vars); i < size; i++ { + spec := expr.vars[i] + + split2 := c.op(opSplit) + if i > 0 { + split3 := c.op(opSplit) + c.compileString(expr.sep) + c.prog.op[split3].i = uint32(len(c.prog.op)) + } + c.compileVarspec(spec, expr) + c.prog.op[split2].i = uint32(len(c.prog.op)) + } + + c.prog.op[split1].i = uint32(len(c.prog.op)) +} + +func (c *compiler) compileLiterals(lt literals) { + c.compileString(string(lt)) +} + +func (c *compiler) compile(tmpl *Template) { + c.op(opLineBegin) + for i := range tmpl.exprs { + expr := tmpl.exprs[i] + switch expr := expr.(type) { + default: + panic("unhandled expression") + case *expression: + c.compileExpression(expr) + case literals: + c.compileLiterals(expr) + } + } + c.op(opLineEnd) + c.op(opEnd) +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/equals.go b/vendor/github.com/yosida95/uritemplate/v3/equals.go new file mode 100644 index 0000000000..aa59a5c030 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/equals.go @@ -0,0 +1,53 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +type CompareFlags uint8 + +const ( + CompareVarname CompareFlags = 1 << iota +) + +// Equals reports whether or not two URI Templates t1 and t2 are equivalent. +func Equals(t1 *Template, t2 *Template, flags CompareFlags) bool { + if len(t1.exprs) != len(t2.exprs) { + return false + } + for i := 0; i < len(t1.exprs); i++ { + switch t1 := t1.exprs[i].(type) { + case literals: + t2, ok := t2.exprs[i].(literals) + if !ok { + return false + } + if t1 != t2 { + return false + } + case *expression: + t2, ok := t2.exprs[i].(*expression) + if !ok { + return false + } + if t1.op != t2.op || len(t1.vars) != len(t2.vars) { + return false + } + for n := 0; n < len(t1.vars); n++ { + v1 := t1.vars[n] + v2 := t2.vars[n] + if flags&CompareVarname == CompareVarname && v1.name != v2.name { + return false + } + if v1.maxlen != v2.maxlen || v1.explode != v2.explode { + return false + } + } + default: + panic("unhandled case") + } + } + return true +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/error.go b/vendor/github.com/yosida95/uritemplate/v3/error.go new file mode 100644 index 0000000000..2fd34a8080 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/error.go @@ -0,0 +1,16 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" +) + +func errorf(pos int, format string, a ...interface{}) error { + msg := fmt.Sprintf(format, a...) + return fmt.Errorf("uritemplate:%d:%s", pos, msg) +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/escape.go b/vendor/github.com/yosida95/uritemplate/v3/escape.go new file mode 100644 index 0000000000..6d27e693af --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/escape.go @@ -0,0 +1,190 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +var ( + hex = []byte("0123456789ABCDEF") + // reserved = gen-delims / sub-delims + // gen-delims = ":" / "/" / "?" / "#" / "[" / "]" / "@" + // sub-delims = "!" / "$" / "&" / "’" / "(" / ")" + // / "*" / "+" / "," / ";" / "=" + rangeReserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x21, Hi: 0x21, Stride: 1}, // '!' + {Lo: 0x23, Hi: 0x24, Stride: 1}, // '#' - '$' + {Lo: 0x26, Hi: 0x2C, Stride: 1}, // '&' - ',' + {Lo: 0x2F, Hi: 0x2F, Stride: 1}, // '/' + {Lo: 0x3A, Hi: 0x3B, Stride: 1}, // ':' - ';' + {Lo: 0x3D, Hi: 0x3D, Stride: 1}, // '=' + {Lo: 0x3F, Hi: 0x40, Stride: 1}, // '?' - '@' + {Lo: 0x5B, Hi: 0x5B, Stride: 1}, // '[' + {Lo: 0x5D, Hi: 0x5D, Stride: 1}, // ']' + }, + LatinOffset: 9, + } + reReserved = `\x21\x23\x24\x26-\x2c\x2f\x3a\x3b\x3d\x3f\x40\x5b\x5d` + // ALPHA = %x41-5A / %x61-7A + // DIGIT = %x30-39 + // unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + rangeUnreserved = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x2D, Hi: 0x2E, Stride: 1}, // '-' - '.' + {Lo: 0x30, Hi: 0x39, Stride: 1}, // '0' - '9' + {Lo: 0x41, Hi: 0x5A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x5F, Hi: 0x5F, Stride: 1}, // '_' + {Lo: 0x61, Hi: 0x7A, Stride: 1}, // 'a' - 'z' + {Lo: 0x7E, Hi: 0x7E, Stride: 1}, // '~' + }, + } + reUnreserved = `\x2d\x2e\x30-\x39\x41-\x5a\x5f\x61-\x7a\x7e` +) + +type runeClass uint8 + +const ( + runeClassU runeClass = 1 << iota + runeClassR + runeClassPctE + runeClassLast + + runeClassUR = runeClassU | runeClassR +) + +var runeClassNames = []string{ + "U", + "R", + "pct-encoded", +} + +func (rc runeClass) String() string { + ret := make([]string, 0, len(runeClassNames)) + for i, j := 0, runeClass(1); j < runeClassLast; j <<= 1 { + if rc&j == j { + ret = append(ret, runeClassNames[i]) + } + i++ + } + return strings.Join(ret, "+") +} + +func pctEncode(w *strings.Builder, r rune) { + if s := r >> 24 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 16 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r >> 8 & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } + if s := r & 0xff; s > 0 { + w.Write([]byte{'%', hex[s/16], hex[s%16]}) + } +} + +func unhex(c byte) byte { + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + } + return 0 +} + +func ishex(c byte) bool { + switch { + case '0' <= c && c <= '9': + return true + case 'a' <= c && c <= 'f': + return true + case 'A' <= c && c <= 'F': + return true + default: + return false + } +} + +func pctDecode(s string) string { + size := len(s) + for i := 0; i < len(s); { + switch s[i] { + case '%': + size -= 2 + i += 3 + default: + i++ + } + } + if size == len(s) { + return s + } + + buf := make([]byte, size) + j := 0 + for i := 0; i < len(s); { + switch c := s[i]; c { + case '%': + buf[j] = unhex(s[i+1])<<4 | unhex(s[i+2]) + i += 3 + j++ + default: + buf[j] = c + i++ + j++ + } + } + return string(buf) +} + +type escapeFunc func(*strings.Builder, string) error + +func escapeLiteral(w *strings.Builder, v string) error { + w.WriteString(v) + return nil +} + +func escapeExceptU(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + if unicode.Is(rangeUnreserved, r) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} + +func escapeExceptUR(w *strings.Builder, v string) error { + for i := 0; i < len(v); { + r, size := utf8.DecodeRuneInString(v[i:]) + if r == utf8.RuneError { + return errorf(i, "invalid encoding") + } + // TODO(yosida95): is pct-encoded triplets allowed here? + if unicode.In(r, rangeUnreserved, rangeReserved) { + w.WriteRune(r) + } else { + pctEncode(w, r) + } + i += size + } + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/expression.go b/vendor/github.com/yosida95/uritemplate/v3/expression.go new file mode 100644 index 0000000000..4858c2ddef --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/expression.go @@ -0,0 +1,173 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "regexp" + "strconv" + "strings" +) + +type template interface { + expand(*strings.Builder, Values) error + regexp(*strings.Builder) +} + +type literals string + +func (l literals) expand(b *strings.Builder, _ Values) error { + b.WriteString(string(l)) + return nil +} + +func (l literals) regexp(b *strings.Builder) { + b.WriteString("(?:") + b.WriteString(regexp.QuoteMeta(string(l))) + b.WriteByte(')') +} + +type varspec struct { + name string + maxlen int + explode bool +} + +type expression struct { + vars []varspec + op parseOp + first string + sep string + named bool + ifemp string + escape escapeFunc + allow runeClass +} + +func (e *expression) init() { + switch e.op { + case parseOpSimple: + e.sep = "," + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpPlus: + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpCrosshatch: + e.first = "#" + e.sep = "," + e.escape = escapeExceptUR + e.allow = runeClassUR + case parseOpDot: + e.first = "." + e.sep = "." + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSlash: + e.first = "/" + e.sep = "/" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpSemicolon: + e.first = ";" + e.sep = ";" + e.named = true + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpQuestion: + e.first = "?" + e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + case parseOpAmpersand: + e.first = "&" + e.sep = "&" + e.named = true + e.ifemp = "=" + e.escape = escapeExceptU + e.allow = runeClassU + } +} + +func (e *expression) expand(w *strings.Builder, values Values) error { + first := true + for _, varspec := range e.vars { + value := values.Get(varspec.name) + if !value.Valid() { + continue + } + + if first { + w.WriteString(e.first) + first = false + } else { + w.WriteString(e.sep) + } + + if err := value.expand(w, varspec, e); err != nil { + return err + } + + } + return nil +} + +func (e *expression) regexp(b *strings.Builder) { + if e.first != "" { + b.WriteString("(?:") // $1 + b.WriteString(regexp.QuoteMeta(e.first)) + } + b.WriteByte('(') // $2 + runeClassToRegexp(b, e.allow, e.named || e.vars[0].explode) + if len(e.vars) > 1 || e.vars[0].explode { + max := len(e.vars) - 1 + for i := 0; i < len(e.vars); i++ { + if e.vars[i].explode { + max = -1 + break + } + } + + b.WriteString("(?:") // $3 + b.WriteString(regexp.QuoteMeta(e.sep)) + runeClassToRegexp(b, e.allow, e.named || max < 0) + b.WriteByte(')') // $3 + if max > 0 { + b.WriteString("{0,") + b.WriteString(strconv.Itoa(max)) + b.WriteByte('}') + } else { + b.WriteByte('*') + } + } + b.WriteByte(')') // $2 + if e.first != "" { + b.WriteByte(')') // $1 + } + b.WriteByte('?') +} + +func runeClassToRegexp(b *strings.Builder, class runeClass, named bool) { + b.WriteString("(?:(?:[") + if class&runeClassR == 0 { + b.WriteString(`\x2c`) + if named { + b.WriteString(`\x3d`) + } + } + if class&runeClassU == runeClassU { + b.WriteString(reUnreserved) + } + if class&runeClassR == runeClassR { + b.WriteString(reReserved) + } + b.WriteString("]") + b.WriteString("|%[[:xdigit:]][[:xdigit:]]") + b.WriteString(")*)") +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/machine.go b/vendor/github.com/yosida95/uritemplate/v3/machine.go new file mode 100644 index 0000000000..7b1d0b518d --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/machine.go @@ -0,0 +1,23 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +// threadList implements https://research.swtch.com/sparse. +type threadList struct { + dense []threadEntry + sparse []uint32 +} + +type threadEntry struct { + pc uint32 + t *thread +} + +type thread struct { + op *progOp + cap map[string][]int +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/match.go b/vendor/github.com/yosida95/uritemplate/v3/match.go new file mode 100644 index 0000000000..02fe6385a3 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/match.go @@ -0,0 +1,213 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "bytes" + "unicode" + "unicode/utf8" +) + +type matcher struct { + prog *prog + + list1 threadList + list2 threadList + matched bool + cap map[string][]int + + input string +} + +func (m *matcher) at(pos int) (rune, int, bool) { + if l := len(m.input); pos < l { + c := m.input[pos] + if c < utf8.RuneSelf { + return rune(c), 1, pos+1 < l + } + r, size := utf8.DecodeRuneInString(m.input[pos:]) + return r, size, pos+size < l + } + return -1, 0, false +} + +func (m *matcher) add(list *threadList, pc uint32, pos int, next bool, cap map[string][]int) { + if i := list.sparse[pc]; i < uint32(len(list.dense)) && list.dense[i].pc == pc { + return + } + + n := len(list.dense) + list.dense = list.dense[:n+1] + list.sparse[pc] = uint32(n) + + e := &list.dense[n] + e.pc = pc + e.t = nil + + op := &m.prog.op[pc] + switch op.code { + default: + panic("unhandled opcode") + case opRune, opRuneClass, opEnd: + e.t = &thread{ + op: &m.prog.op[pc], + cap: make(map[string][]int, len(m.cap)), + } + for k, v := range cap { + e.t.cap[k] = make([]int, len(v)) + copy(e.t.cap[k], v) + } + case opLineBegin: + if pos == 0 { + m.add(list, pc+1, pos, next, cap) + } + case opLineEnd: + if !next { + m.add(list, pc+1, pos, next, cap) + } + case opCapStart, opCapEnd: + ocap := make(map[string][]int, len(m.cap)) + for k, v := range cap { + ocap[k] = make([]int, len(v)) + copy(ocap[k], v) + } + ocap[op.name] = append(ocap[op.name], pos) + m.add(list, pc+1, pos, next, ocap) + case opSplit: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmp: + m.add(list, op.i, pos, next, cap) + case opJmpIfNotDefined: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotFirst: + m.add(list, pc+1, pos, next, cap) + m.add(list, op.i, pos, next, cap) + case opJmpIfNotEmpty: + m.add(list, op.i, pos, next, cap) + m.add(list, pc+1, pos, next, cap) + case opNoop: + m.add(list, pc+1, pos, next, cap) + } +} + +func (m *matcher) step(clist *threadList, nlist *threadList, r rune, pos int, nextPos int, next bool) { + debug.Printf("===== %q =====", string(r)) + for i := 0; i < len(clist.dense); i++ { + e := clist.dense[i] + if debug { + var buf bytes.Buffer + dumpProg(&buf, m.prog, e.pc) + debug.Printf("\n%s", buf.String()) + } + if e.t == nil { + continue + } + + t := e.t + op := t.op + switch op.code { + default: + panic("unhandled opcode") + case opRune: + if op.r == r { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opRuneClass: + ret := false + if !ret && op.rc&runeClassU == runeClassU { + ret = ret || unicode.Is(rangeUnreserved, r) + } + if !ret && op.rc&runeClassR == runeClassR { + ret = ret || unicode.Is(rangeReserved, r) + } + if !ret && op.rc&runeClassPctE == runeClassPctE { + ret = ret || unicode.Is(unicode.ASCII_Hex_Digit, r) + } + if ret { + m.add(nlist, e.pc+1, nextPos, next, t.cap) + } + case opEnd: + m.matched = true + for k, v := range t.cap { + m.cap[k] = make([]int, len(v)) + copy(m.cap[k], v) + } + clist.dense = clist.dense[:0] + } + } + clist.dense = clist.dense[:0] +} + +func (m *matcher) match() bool { + pos := 0 + clist, nlist := &m.list1, &m.list2 + for { + if len(clist.dense) == 0 && m.matched { + break + } + r, width, next := m.at(pos) + if !m.matched { + m.add(clist, 0, pos, next, m.cap) + } + m.step(clist, nlist, r, pos, pos+width, next) + + if width < 1 { + break + } + pos += width + + clist, nlist = nlist, clist + } + return m.matched +} + +func (tmpl *Template) Match(expansion string) Values { + tmpl.mu.Lock() + if tmpl.prog == nil { + c := compiler{} + c.init() + c.compile(tmpl) + tmpl.prog = c.prog + } + prog := tmpl.prog + tmpl.mu.Unlock() + + n := len(prog.op) + m := matcher{ + prog: prog, + list1: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + list2: threadList{ + dense: make([]threadEntry, 0, n), + sparse: make([]uint32, n), + }, + cap: make(map[string][]int, prog.numCap), + input: expansion, + } + if !m.match() { + return nil + } + + match := make(Values, len(m.cap)) + for name, indices := range m.cap { + v := Value{V: make([]string, len(indices)/2)} + for i := range v.V { + v.V[i] = pctDecode(expansion[indices[2*i]:indices[2*i+1]]) + } + if len(v.V) == 1 { + v.T = ValueTypeString + } else { + v.T = ValueTypeList + } + match[name] = v + } + return match +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/parse.go b/vendor/github.com/yosida95/uritemplate/v3/parse.go new file mode 100644 index 0000000000..fd38a682f1 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/parse.go @@ -0,0 +1,277 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "fmt" + "unicode" + "unicode/utf8" +) + +type parseOp int + +const ( + parseOpSimple parseOp = iota + parseOpPlus + parseOpCrosshatch + parseOpDot + parseOpSlash + parseOpSemicolon + parseOpQuestion + parseOpAmpersand +) + +var ( + rangeVarchar = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0030, Hi: 0x0039, Stride: 1}, // '0' - '9' + {Lo: 0x0041, Hi: 0x005A, Stride: 1}, // 'A' - 'Z' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + }, + LatinOffset: 4, + } + rangeLiterals = &unicode.RangeTable{ + R16: []unicode.Range16{ + {Lo: 0x0021, Hi: 0x0021, Stride: 1}, // '!' + {Lo: 0x0023, Hi: 0x0024, Stride: 1}, // '#' - '$' + {Lo: 0x0026, Hi: 0x003B, Stride: 1}, // '&' ''' '(' - ';'. '''/27 used to be excluded but an errata is in the review process https://www.rfc-editor.org/errata/eid6937 + {Lo: 0x003D, Hi: 0x003D, Stride: 1}, // '=' + {Lo: 0x003F, Hi: 0x005B, Stride: 1}, // '?' - '[' + {Lo: 0x005D, Hi: 0x005D, Stride: 1}, // ']' + {Lo: 0x005F, Hi: 0x005F, Stride: 1}, // '_' + {Lo: 0x0061, Hi: 0x007A, Stride: 1}, // 'a' - 'z' + {Lo: 0x007E, Hi: 0x007E, Stride: 1}, // '~' + {Lo: 0x00A0, Hi: 0xD7FF, Stride: 1}, // ucschar + {Lo: 0xE000, Hi: 0xF8FF, Stride: 1}, // iprivate + {Lo: 0xF900, Hi: 0xFDCF, Stride: 1}, // ucschar + {Lo: 0xFDF0, Hi: 0xFFEF, Stride: 1}, // ucschar + }, + R32: []unicode.Range32{ + {Lo: 0x00010000, Hi: 0x0001FFFD, Stride: 1}, // ucschar + {Lo: 0x00020000, Hi: 0x0002FFFD, Stride: 1}, // ucschar + {Lo: 0x00030000, Hi: 0x0003FFFD, Stride: 1}, // ucschar + {Lo: 0x00040000, Hi: 0x0004FFFD, Stride: 1}, // ucschar + {Lo: 0x00050000, Hi: 0x0005FFFD, Stride: 1}, // ucschar + {Lo: 0x00060000, Hi: 0x0006FFFD, Stride: 1}, // ucschar + {Lo: 0x00070000, Hi: 0x0007FFFD, Stride: 1}, // ucschar + {Lo: 0x00080000, Hi: 0x0008FFFD, Stride: 1}, // ucschar + {Lo: 0x00090000, Hi: 0x0009FFFD, Stride: 1}, // ucschar + {Lo: 0x000A0000, Hi: 0x000AFFFD, Stride: 1}, // ucschar + {Lo: 0x000B0000, Hi: 0x000BFFFD, Stride: 1}, // ucschar + {Lo: 0x000C0000, Hi: 0x000CFFFD, Stride: 1}, // ucschar + {Lo: 0x000D0000, Hi: 0x000DFFFD, Stride: 1}, // ucschar + {Lo: 0x000E1000, Hi: 0x000EFFFD, Stride: 1}, // ucschar + {Lo: 0x000F0000, Hi: 0x000FFFFD, Stride: 1}, // iprivate + {Lo: 0x00100000, Hi: 0x0010FFFD, Stride: 1}, // iprivate + }, + LatinOffset: 10, + } +) + +type parser struct { + r string + start int + stop int + state parseState +} + +func (p *parser) errorf(i rune, format string, a ...interface{}) error { + return fmt.Errorf("%s: %s%s", fmt.Sprintf(format, a...), p.r[0:p.stop], string(i)) +} + +func (p *parser) rune() (rune, int) { + r, size := utf8.DecodeRuneInString(p.r[p.stop:]) + if r != utf8.RuneError { + p.stop += size + } + return r, size +} + +func (p *parser) unread(r rune) { + p.stop -= utf8.RuneLen(r) +} + +type parseState int + +const ( + parseStateDefault = parseState(iota) + parseStateOperator + parseStateVarList + parseStateVarName + parseStatePrefix +) + +func (p *parser) setState(state parseState) { + p.state = state + p.start = p.stop +} + +func (p *parser) parseURITemplate() (*Template, error) { + tmpl := Template{ + raw: p.r, + exprs: []template{}, + } + + var exp *expression + for { + r, size := p.rune() + if r == utf8.RuneError { + if size == 0 { + if p.state != parseStateDefault { + return nil, p.errorf('_', "incomplete expression") + } + if p.start < p.stop { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:p.stop])) + } + return &tmpl, nil + } + return nil, p.errorf('_', "invalid UTF-8 sequence") + } + + switch p.state { + case parseStateDefault: + switch r { + case '{': + if stop := p.stop - size; stop > p.start { + tmpl.exprs = append(tmpl.exprs, literals(p.r[p.start:stop])) + } + exp = &expression{} + tmpl.exprs = append(tmpl.exprs, exp) + p.setState(parseStateOperator) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + default: + if !unicode.Is(rangeLiterals, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable character (hint: use %%XX encoding)") + } + } + case parseStateOperator: + switch r { + default: + p.unread(r) + exp.op = parseOpSimple + case '+': + exp.op = parseOpPlus + case '#': + exp.op = parseOpCrosshatch + case '.': + exp.op = parseOpDot + case '/': + exp.op = parseOpSlash + case ';': + exp.op = parseOpSemicolon + case '?': + exp.op = parseOpQuestion + case '&': + exp.op = parseOpAmpersand + case '=', ',', '!', '@', '|': // op-reserved + return nil, p.errorf('|', "unimplemented operator (op-reserved)") + } + p.setState(parseStateVarName) + case parseStateVarList: + switch r { + case ',': + p.setState(parseStateVarName) + case '}': + exp.init() + p.setState(parseStateDefault) + default: + p.unread(r) + return nil, p.errorf('_', "unrecognized value modifier") + } + case parseStateVarName: + switch r { + case ':', '*': + name := p.r[p.start : p.stop-size] + if !isValidVarname(name) { + return nil, p.errorf('|', "unacceptable variable name") + } + explode := r == '*' + exp.vars = append(exp.vars, varspec{ + name: name, + explode: explode, + }) + if explode { + p.setState(parseStateVarList) + } else { + p.setState(parseStatePrefix) + } + case ',', '}': + p.unread(r) + name := p.r[p.start:p.stop] + if !isValidVarname(name) { + return nil, p.errorf('|', "unacceptable variable name") + } + exp.vars = append(exp.vars, varspec{ + name: name, + }) + p.setState(parseStateVarList) + case '%': + p.unread(r) + if err := p.consumeTriplet(); err != nil { + return nil, err + } + case '.': + if dot := p.stop - size; dot == p.start || p.r[dot-1] == '.' { + return nil, p.errorf('|', "unacceptable variable name") + } + default: + if !unicode.Is(rangeVarchar, r) { + p.unread(r) + return nil, p.errorf('_', "unacceptable variable name") + } + } + case parseStatePrefix: + spec := &(exp.vars[len(exp.vars)-1]) + switch { + case '0' <= r && r <= '9': + spec.maxlen *= 10 + spec.maxlen += int(r - '0') + if spec.maxlen == 0 || spec.maxlen > 9999 { + return nil, p.errorf('|', "max-length must be (0, 9999]") + } + default: + p.unread(r) + if spec.maxlen == 0 { + return nil, p.errorf('_', "max-length must be (0, 9999]") + } + p.setState(parseStateVarList) + } + default: + p.unread(r) + panic(p.errorf('_', "unhandled parseState(%d)", p.state)) + } + } +} + +func isValidVarname(name string) bool { + if l := len(name); l == 0 || name[0] == '.' || name[l-1] == '.' { + return false + } + for i := 1; i < len(name)-1; i++ { + switch c := name[i]; c { + case '.': + if name[i-1] == '.' { + return false + } + } + } + return true +} + +func (p *parser) consumeTriplet() error { + if len(p.r)-p.stop < 3 || p.r[p.stop] != '%' || !ishex(p.r[p.stop+1]) || !ishex(p.r[p.stop+2]) { + return p.errorf('_', "incomplete pct-encodeed") + } + p.stop += 3 + return nil +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/prog.go b/vendor/github.com/yosida95/uritemplate/v3/prog.go new file mode 100644 index 0000000000..97af4f0eab --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/prog.go @@ -0,0 +1,130 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "bytes" + "strconv" +) + +type progOpcode uint16 + +const ( + // match + opRune progOpcode = iota + opRuneClass + opLineBegin + opLineEnd + // capture + opCapStart + opCapEnd + // stack + opSplit + opJmp + opJmpIfNotDefined + opJmpIfNotEmpty + opJmpIfNotFirst + // result + opEnd + // fake + opNoop + opcodeMax +) + +var opcodeNames = []string{ + // match + "opRune", + "opRuneClass", + "opLineBegin", + "opLineEnd", + // capture + "opCapStart", + "opCapEnd", + // stack + "opSplit", + "opJmp", + "opJmpIfNotDefined", + "opJmpIfNotEmpty", + "opJmpIfNotFirst", + // result + "opEnd", +} + +func (code progOpcode) String() string { + if code >= opcodeMax { + return "" + } + return opcodeNames[code] +} + +type progOp struct { + code progOpcode + r rune + rc runeClass + i uint32 + + name string +} + +func dumpProgOp(b *bytes.Buffer, op *progOp) { + b.WriteString(op.code.String()) + switch op.code { + case opRune: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(string(op.r))) + b.WriteString(")") + case opRuneClass: + b.WriteString("(") + b.WriteString(op.rc.String()) + b.WriteString(")") + case opCapStart, opCapEnd: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + b.WriteString(")") + case opSplit: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmp, opJmpIfNotFirst: + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + case opJmpIfNotDefined, opJmpIfNotEmpty: + b.WriteString("(") + b.WriteString(strconv.QuoteToASCII(op.name)) + b.WriteString(")") + b.WriteString(" -> ") + b.WriteString(strconv.FormatInt(int64(op.i), 10)) + } +} + +type prog struct { + op []progOp + numCap int +} + +func dumpProg(b *bytes.Buffer, prog *prog, pc uint32) { + for i := range prog.op { + op := prog.op[i] + + pos := strconv.Itoa(i) + if uint32(i) == pc { + pos = "*" + pos + } + b.WriteString(" "[len(pos):]) + b.WriteString(pos) + + b.WriteByte('\t') + dumpProgOp(b, &op) + + b.WriteByte('\n') + } +} + +func (p *prog) String() string { + b := bytes.Buffer{} + dumpProg(&b, p, 0) + return b.String() +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go new file mode 100644 index 0000000000..dbd2673753 --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/uritemplate.go @@ -0,0 +1,116 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import ( + "log" + "regexp" + "strings" + "sync" +) + +var ( + debug = debugT(false) +) + +type debugT bool + +func (t debugT) Printf(format string, v ...interface{}) { + if t { + log.Printf(format, v...) + } +} + +// Template represents a URI Template. +type Template struct { + raw string + exprs []template + + // protects the rest of fields + mu sync.Mutex + varnames []string + re *regexp.Regexp + prog *prog +} + +// New parses and constructs a new Template instance based on the template. +// New returns an error if the template cannot be recognized. +func New(template string) (*Template, error) { + return (&parser{r: template}).parseURITemplate() +} + +// MustNew panics if the template cannot be recognized. +func MustNew(template string) *Template { + ret, err := New(template) + if err != nil { + panic(err) + } + return ret +} + +// Raw returns a raw URI template passed to New in string. +func (t *Template) Raw() string { + return t.raw +} + +// Varnames returns variable names used in the template. +func (t *Template) Varnames() []string { + t.mu.Lock() + defer t.mu.Unlock() + if t.varnames != nil { + return t.varnames + } + + reg := map[string]struct{}{} + t.varnames = []string{} + for i := range t.exprs { + expr, ok := t.exprs[i].(*expression) + if !ok { + continue + } + for _, spec := range expr.vars { + if _, ok := reg[spec.name]; ok { + continue + } + reg[spec.name] = struct{}{} + t.varnames = append(t.varnames, spec.name) + } + } + + return t.varnames +} + +// Expand returns a URI reference corresponding to the template expanded using the passed variables. +func (t *Template) Expand(vars Values) (string, error) { + var w strings.Builder + for i := range t.exprs { + expr := t.exprs[i] + if err := expr.expand(&w, vars); err != nil { + return w.String(), err + } + } + return w.String(), nil +} + +// Regexp converts the template to regexp and returns compiled *regexp.Regexp. +func (t *Template) Regexp() *regexp.Regexp { + t.mu.Lock() + defer t.mu.Unlock() + if t.re != nil { + return t.re + } + + var b strings.Builder + b.WriteByte('^') + for _, expr := range t.exprs { + expr.regexp(&b) + } + b.WriteByte('$') + t.re = regexp.MustCompile(b.String()) + + return t.re +} diff --git a/vendor/github.com/yosida95/uritemplate/v3/value.go b/vendor/github.com/yosida95/uritemplate/v3/value.go new file mode 100644 index 0000000000..0550eabdbf --- /dev/null +++ b/vendor/github.com/yosida95/uritemplate/v3/value.go @@ -0,0 +1,216 @@ +// Copyright (C) 2016 Kohei YOSHIDA. All rights reserved. +// +// This program is free software; you can redistribute it and/or +// modify it under the terms of The BSD 3-Clause License +// that can be found in the LICENSE file. + +package uritemplate + +import "strings" + +// A varname containing pct-encoded characters is not the same variable as +// a varname with those same characters decoded. +// +// -- https://tools.ietf.org/html/rfc6570#section-2.3 +type Values map[string]Value + +func (v Values) Set(name string, value Value) { + v[name] = value +} + +func (v Values) Get(name string) Value { + if v == nil { + return Value{} + } + return v[name] +} + +type ValueType uint8 + +const ( + ValueTypeString = iota + ValueTypeList + ValueTypeKV + valueTypeLast +) + +var valueTypeNames = []string{ + "String", + "List", + "KV", +} + +func (vt ValueType) String() string { + if vt < valueTypeLast { + return valueTypeNames[vt] + } + return "" +} + +type Value struct { + T ValueType + V []string +} + +func (v Value) String() string { + if v.Valid() && v.T == ValueTypeString { + return v.V[0] + } + return "" +} + +func (v Value) List() []string { + if v.Valid() && v.T == ValueTypeList { + return v.V + } + return nil +} + +func (v Value) KV() []string { + if v.Valid() && v.T == ValueTypeKV { + return v.V + } + return nil +} + +func (v Value) Valid() bool { + switch v.T { + default: + return false + case ValueTypeString: + return len(v.V) > 0 + case ValueTypeList: + return len(v.V) > 0 + case ValueTypeKV: + return len(v.V) > 0 && len(v.V)%2 == 0 + } +} + +func (v Value) expand(w *strings.Builder, spec varspec, exp *expression) error { + switch v.T { + case ValueTypeString: + val := v.V[0] + var maxlen int + if max := len(val); spec.maxlen < 1 || spec.maxlen > max { + maxlen = max + } else { + maxlen = spec.maxlen + } + + if exp.named { + w.WriteString(spec.name) + if val == "" { + w.WriteString(exp.ifemp) + return nil + } + w.WriteByte('=') + } + return exp.escape(w, val[:maxlen]) + case ValueTypeList: + var sep string + if spec.explode { + sep = exp.sep + } else { + sep = "," + } + + var pre string + var preifemp string + if spec.explode && exp.named { + pre = spec.name + "=" + preifemp = spec.name + exp.ifemp + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + for i := range v.V { + val := v.V[i] + if i > 0 { + w.WriteString(sep) + } + if val == "" { + w.WriteString(preifemp) + continue + } + w.WriteString(pre) + + if err := exp.escape(w, val); err != nil { + return err + } + } + case ValueTypeKV: + var sep string + var kvsep string + if spec.explode { + sep = exp.sep + kvsep = "=" + } else { + sep = "," + kvsep = "," + } + + var ifemp string + var kescape escapeFunc + if spec.explode && exp.named { + ifemp = exp.ifemp + kescape = escapeLiteral + } else { + ifemp = "," + kescape = exp.escape + } + + if !spec.explode && exp.named { + w.WriteString(spec.name) + w.WriteByte('=') + } + + for i := 0; i < len(v.V); i += 2 { + if i > 0 { + w.WriteString(sep) + } + if err := kescape(w, v.V[i]); err != nil { + return err + } + if v.V[i+1] == "" { + w.WriteString(ifemp) + continue + } + w.WriteString(kvsep) + + if err := exp.escape(w, v.V[i+1]); err != nil { + return err + } + } + } + return nil +} + +// String returns Value that represents string. +func String(v string) Value { + return Value{ + T: ValueTypeString, + V: []string{v}, + } +} + +// List returns Value that represents list. +func List(v ...string) Value { + return Value{ + T: ValueTypeList, + V: v, + } +} + +// KV returns Value that represents associative list. +// KV panics if len(kv) is not even. +func KV(kv ...string) Value { + if len(kv)%2 != 0 { + panic("uritemplate.go: count of the kv must be even number") + } + return Value{ + T: ValueTypeKV, + V: kv, + } +} diff --git a/vendor/modules.txt b/vendor/modules.txt index d2450c00eb..7cb77b7118 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -275,6 +275,9 @@ github.com/aws/aws-sdk-go/service/sso/ssoiface github.com/aws/aws-sdk-go/service/sts github.com/aws/aws-sdk-go/service/sts/stsiface github.com/aws/aws-sdk-go/service/wafv2 +# github.com/bahlo/generic-list-go v0.2.0 +## explicit; go 1.18 +github.com/bahlo/generic-list-go # github.com/basgys/goxml2json v1.1.1-0.20181031222924-996d9fc8d313 ## explicit github.com/basgys/goxml2json @@ -306,6 +309,9 @@ github.com/boombuler/barcode/utils # github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2 ## explicit github.com/bradfitz/iter +# github.com/buger/jsonparser v1.1.1 +## explicit; go 1.13 +github.com/buger/jsonparser # github.com/c-bata/go-prompt v0.2.4 ## explicit github.com/c-bata/go-prompt @@ -542,8 +548,6 @@ github.com/fernet/fernet-go # github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568 ## explicit github.com/flynn/go-shlex -# github.com/frankban/quicktest v1.14.3 -## explicit; go 1.13 # github.com/fsnotify/fsnotify v1.4.9 ## explicit; go 1.13 github.com/fsnotify/fsnotify @@ -782,6 +786,9 @@ github.com/influxdata/promql/v2 github.com/influxdata/promql/v2/pkg/labels github.com/influxdata/promql/v2/pkg/value github.com/influxdata/promql/v2/util/strutil +# github.com/invopop/jsonschema v0.13.0 +## explicit; go 1.18 +github.com/invopop/jsonschema # github.com/jaypipes/ghw v0.11.0 ## explicit; go 1.18 github.com/jaypipes/ghw/pkg/context @@ -892,6 +899,15 @@ github.com/lufia/plan9stats # github.com/ma314smith/signedxml v0.0.0-20210628192057-abc5b481ae1c ## explicit github.com/ma314smith/signedxml +# github.com/mailru/easyjson v0.7.7 +## explicit; go 1.12 +github.com/mailru/easyjson/buffer +github.com/mailru/easyjson/jwriter +# github.com/mark3labs/mcp-go v0.39.1 +## explicit; go 1.23 +github.com/mark3labs/mcp-go/mcp +github.com/mark3labs/mcp-go/server +github.com/mark3labs/mcp-go/util # github.com/mattn/go-colorable v0.1.9 ## explicit; go 1.13 github.com/mattn/go-colorable @@ -1174,6 +1190,9 @@ github.com/smartystreets/goconvey/convey/reporting # github.com/spaolacci/murmur3 v1.1.0 ## explicit github.com/spaolacci/murmur3 +# github.com/spf13/cast v1.7.1 +## explicit; go 1.19 +github.com/spf13/cast # github.com/spf13/pflag v1.0.5 ## explicit; go 1.12 github.com/spf13/pflag @@ -1282,6 +1301,9 @@ github.com/willf/bitset # github.com/willf/bloom v2.0.3+incompatible ## explicit github.com/willf/bloom +# github.com/wk8/go-ordered-map/v2 v2.1.8 +## explicit; go 1.18 +github.com/wk8/go-ordered-map/v2 # github.com/xuri/efp v0.0.0-20220603152613-6918739fd470 ## explicit; go 1.11 github.com/xuri/efp @@ -1291,6 +1313,9 @@ github.com/xuri/excelize/v2 # github.com/xuri/nfp v0.0.0-20220409054826-5e722a1d9e22 ## explicit; go 1.15 github.com/xuri/nfp +# github.com/yosida95/uritemplate/v3 v3.0.2 +## explicit; go 1.14 +github.com/yosida95/uritemplate/v3 # github.com/yusufpapurcu/wmi v1.2.2 ## explicit; go 1.16 github.com/yusufpapurcu/wmi @@ -1852,7 +1877,7 @@ sigs.k8s.io/structured-merge-diff/v4/value # sigs.k8s.io/yaml v1.2.0 ## explicit; go 1.12 sigs.k8s.io/yaml -# yunion.io/x/cloudmux v0.3.10-0-alpha.1.0.20250912144144-d0d8cf049d7f +# yunion.io/x/cloudmux v0.3.10-0-alpha.1.0.20250915054625-7251d9eeceec ## explicit; go 1.24 yunion.io/x/cloudmux/pkg/apis yunion.io/x/cloudmux/pkg/apis/billing diff --git a/vendor/yunion.io/x/cloudmux/pkg/multicloud/aws/dbinstance_cluster.go b/vendor/yunion.io/x/cloudmux/pkg/multicloud/aws/dbinstance_cluster.go index 197ce60ffc..fc11184222 100644 --- a/vendor/yunion.io/x/cloudmux/pkg/multicloud/aws/dbinstance_cluster.go +++ b/vendor/yunion.io/x/cloudmux/pkg/multicloud/aws/dbinstance_cluster.go @@ -172,7 +172,7 @@ func (rds *SDBInstanceCluster) GetMaintainTime() string { } func (rds *SDBInstanceCluster) GetConnectionStr() string { - return "" + return rds.Endpoint } func (rds *SDBInstanceCluster) GetInternalConnectionStr() string {