mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-07 06:01:20 +08:00
Merge pull request #1457 from Jacksonxhx/milvus
Integrated Milvus with MetaGPT
This commit is contained in:
99
metagpt/document_store/milvus_store.py
Normal file
99
metagpt/document_store/milvus_store.py
Normal file
@@ -0,0 +1,99 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from metagpt.document_store.base_store import BaseStore
|
||||
|
||||
|
||||
@dataclass
|
||||
class MilvusConnection:
|
||||
"""
|
||||
Args:
|
||||
uri: milvus url
|
||||
token: milvus token
|
||||
"""
|
||||
|
||||
uri: str = None
|
||||
token: str = None
|
||||
|
||||
|
||||
class MilvusStore(BaseStore):
|
||||
def __init__(self, connect: MilvusConnection):
|
||||
try:
|
||||
from pymilvus import MilvusClient
|
||||
except ImportError:
|
||||
raise Exception("Please install pymilvus first.")
|
||||
if not connect.uri:
|
||||
raise Exception("please check MilvusConnection, uri must be set.")
|
||||
self.client = MilvusClient(uri=connect.uri, token=connect.token)
|
||||
|
||||
def create_collection(self, collection_name: str, dim: int, enable_dynamic_schema: bool = True):
|
||||
from pymilvus import DataType
|
||||
|
||||
if self.client.has_collection(collection_name=collection_name):
|
||||
self.client.drop_collection(collection_name=collection_name)
|
||||
|
||||
schema = self.client.create_schema(
|
||||
auto_id=False,
|
||||
enable_dynamic_field=False,
|
||||
)
|
||||
schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=36)
|
||||
schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dim)
|
||||
|
||||
index_params = self.client.prepare_index_params()
|
||||
index_params.add_index(field_name="vector", index_type="AUTOINDEX", metric_type="COSINE")
|
||||
|
||||
self.client.create_collection(
|
||||
collection_name=collection_name,
|
||||
schema=schema,
|
||||
index_params=index_params,
|
||||
enable_dynamic_schema=enable_dynamic_schema,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_filter(key, value) -> str:
|
||||
if isinstance(value, str):
|
||||
filter_expression = f'{key} == "{value}"'
|
||||
else:
|
||||
if isinstance(value, list):
|
||||
filter_expression = f"{key} in {value}"
|
||||
else:
|
||||
filter_expression = f"{key} == {value}"
|
||||
|
||||
return filter_expression
|
||||
|
||||
def search(
|
||||
self,
|
||||
collection_name: str,
|
||||
query: List[float],
|
||||
filter: Dict = None,
|
||||
limit: int = 10,
|
||||
output_fields: Optional[List[str]] = None,
|
||||
) -> List[dict]:
|
||||
filter_expression = " and ".join([self.build_filter(key, value) for key, value in filter.items()])
|
||||
print(filter_expression)
|
||||
|
||||
res = self.client.search(
|
||||
collection_name=collection_name,
|
||||
data=[query],
|
||||
filter=filter_expression,
|
||||
limit=limit,
|
||||
output_fields=output_fields,
|
||||
)[0]
|
||||
|
||||
return res
|
||||
|
||||
def add(self, collection_name: str, _ids: List[str], vector: List[List[float]], metadata: List[Dict[str, Any]]):
|
||||
data = dict()
|
||||
|
||||
for i, id in enumerate(_ids):
|
||||
data["id"] = id
|
||||
data["vector"] = vector[i]
|
||||
data["metadata"] = metadata[i]
|
||||
|
||||
self.client.upsert(collection_name=collection_name, data=data)
|
||||
|
||||
def delete(self, collection_name: str, _ids: List[str]):
|
||||
self.client.delete(collection_name=collection_name, ids=_ids)
|
||||
|
||||
def write(self, *args, **kwargs):
|
||||
pass
|
||||
@@ -8,6 +8,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.schema import (
|
||||
@@ -17,6 +18,7 @@ from metagpt.rag.schema import (
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchKeywordIndexConfig,
|
||||
FAISSIndexConfig,
|
||||
MilvusIndexConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -28,6 +30,7 @@ class RAGIndexFactory(ConfigBasedFactory):
|
||||
BM25IndexConfig: self._create_bm25,
|
||||
ElasticsearchIndexConfig: self._create_es,
|
||||
ElasticsearchKeywordIndexConfig: self._create_es,
|
||||
MilvusIndexConfig: self._create_milvus
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
@@ -46,6 +49,11 @@ class RAGIndexFactory(ConfigBasedFactory):
|
||||
|
||||
return self._index_from_storage(storage_context=storage_context, config=config, **kwargs)
|
||||
|
||||
def _create_milvus(self, config: MilvusIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = MilvusVectorStore(collection_name=config.collection_name, uri=config.uri, token=config.token)
|
||||
|
||||
return self._index_from_vector_store(vector_store=vector_store, config=config, **kwargs)
|
||||
|
||||
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
db = chromadb.PersistentClient(str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name, metadata=config.metadata)
|
||||
|
||||
@@ -12,6 +12,7 @@ from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
@@ -20,6 +21,7 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
@@ -27,6 +29,7 @@ from metagpt.rag.schema import (
|
||||
ElasticsearchKeywordRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
MilvusRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -56,6 +59,7 @@ class RetrieverFactory(ConfigBasedFactory):
|
||||
ChromaRetrieverConfig: self._create_chroma_retriever,
|
||||
ElasticsearchRetrieverConfig: self._create_es_retriever,
|
||||
ElasticsearchKeywordRetrieverConfig: self._create_es_retriever,
|
||||
MilvusRetrieverConfig: self._create_milvus_retriever,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
@@ -76,6 +80,11 @@ class RetrieverFactory(ConfigBasedFactory):
|
||||
|
||||
return index.as_retriever()
|
||||
|
||||
def _create_milvus_retriever(self, config: MilvusRetrieverConfig, **kwargs) -> MilvusRetriever:
|
||||
config.index = self._build_milvus_index(config, **kwargs)
|
||||
|
||||
return MilvusRetriever(**config.model_dump())
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
config.index = self._build_faiss_index(config, **kwargs)
|
||||
|
||||
@@ -128,6 +137,12 @@ class RetrieverFactory(ConfigBasedFactory):
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_milvus_index(self, config: MilvusRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = MilvusVectorStore(uri=config.uri, collection_name=config.collection_name, token=config.token, dim=config.dimensions)
|
||||
|
||||
return self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
|
||||
@get_or_build_index
|
||||
def _build_es_index(self, config: ElasticsearchRetrieverConfig, **kwargs) -> VectorStoreIndex:
|
||||
vector_store = ElasticsearchStore(**config.store_config.model_dump())
|
||||
|
||||
17
metagpt/rag/retrievers/milvus_retriever.py
Normal file
17
metagpt/rag/retrievers/milvus_retriever.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Milvus retriever."""
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
|
||||
class MilvusRetriever(VectorIndexRetriever):
|
||||
"""Milvus retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes."""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist.
|
||||
|
||||
Milvus automatically saves, so there is no need to implement."""
|
||||
@@ -8,7 +8,7 @@ from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.core.vector_stores.types import VectorStoreQueryMode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator, validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
@@ -62,6 +62,36 @@ class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class MilvusRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for Milvus-based retrievers."""
|
||||
|
||||
uri: str = Field(default="./milvus_local.db", description="The directory to save data.")
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
token: str = Field(default=None, description="The token for Milvus")
|
||||
metadata: Optional[CollectionMetadata] = Field(
|
||||
default=None, description="Optional metadata to associate with the collection"
|
||||
)
|
||||
dimensions: int = Field(default=0, description="Dimensionality of the vectors for Milvus index construction.")
|
||||
|
||||
_embedding_type_to_dimensions: ClassVar[dict[EmbeddingType, int]] = {
|
||||
EmbeddingType.GEMINI: 768,
|
||||
EmbeddingType.OLLAMA: 4096,
|
||||
}
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_dimensions(self):
|
||||
if self.dimensions == 0:
|
||||
self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get(
|
||||
config.embedding.api_type, 1536
|
||||
)
|
||||
if not config.embedding.dimensions and config.embedding.api_type not in self._embedding_type_to_dimensions:
|
||||
logger.warning(
|
||||
f"You didn't set dimensions in config when using {config.embedding.api_type}, default to 1536"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class ChromaRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for Chroma-based retrievers."""
|
||||
|
||||
@@ -169,6 +199,16 @@ class ChromaIndexConfig(VectorIndexConfig):
|
||||
default=None, description="Optional metadata to associate with the collection"
|
||||
)
|
||||
|
||||
class MilvusIndexConfig(VectorIndexConfig):
|
||||
"""Config for milvus-based index."""
|
||||
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
uri: str = Field(default="./milvus_local.db", description="The uri of the index.")
|
||||
token: Optional[str] = Field(default=None, description="The token of the index.")
|
||||
metadata: Optional[CollectionMetadata] = Field(
|
||||
default=None, description="Optional metadata to associate with the collection"
|
||||
)
|
||||
|
||||
|
||||
class BM25IndexConfig(BaseIndexConfig):
|
||||
"""Config for bm25-based index."""
|
||||
|
||||
@@ -19,7 +19,7 @@ beautifulsoup4==4.12.3
|
||||
pandas==2.1.1
|
||||
pydantic>=2.5.3
|
||||
#pygame==2.1.3
|
||||
#pymilvus==2.2.8
|
||||
# pymilvus==2.4.6
|
||||
# pytest==7.2.2 # test extras require
|
||||
python_docx==0.8.11
|
||||
PyYAML==6.0.1
|
||||
@@ -78,4 +78,4 @@ volcengine-python-sdk[ark]~=1.0.94
|
||||
gymnasium==0.29.1
|
||||
boto3~=1.34.69
|
||||
spark_ai_python~=0.3.30
|
||||
agentops
|
||||
agentops
|
||||
1
setup.py
1
setup.py
@@ -43,6 +43,7 @@ extras_require = {
|
||||
"llama-index-postprocessor-cohere-rerank==0.1.4",
|
||||
"llama-index-postprocessor-colbert-rerank==0.1.1",
|
||||
"llama-index-postprocessor-flag-embedding-reranker==0.1.2",
|
||||
# "llama-index-vector-stores-milvus==0.1.23",
|
||||
"docx2txt==0.8",
|
||||
],
|
||||
}
|
||||
|
||||
48
tests/metagpt/document_store/test_milvus_store.py
Normal file
48
tests/metagpt/document_store/test_milvus_store.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
|
||||
|
||||
seed_value = 42
|
||||
random.seed(seed_value)
|
||||
|
||||
vectors = [[random.random() for _ in range(8)] for _ in range(10)]
|
||||
ids = [f"doc_{i}" for i in range(10)]
|
||||
metadata = [{"color": "red", "rand_number": i % 10} for i in range(10)]
|
||||
|
||||
|
||||
def assert_almost_equal(actual, expected):
|
||||
delta = 1e-10
|
||||
if isinstance(expected, list):
|
||||
assert len(actual) == len(expected)
|
||||
for ac, exp in zip(actual, expected):
|
||||
assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}"
|
||||
else:
|
||||
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
|
||||
|
||||
|
||||
@pytest.mark.skip() # Skip because the pymilvus dependency is not installed by default
|
||||
def test_milvus_store():
|
||||
milvus_connection = MilvusConnection(uri="./milvus_local.db")
|
||||
milvus_store = MilvusStore(milvus_connection)
|
||||
|
||||
collection_name = "TestCollection"
|
||||
milvus_store.create_collection(collection_name, dim=8)
|
||||
|
||||
milvus_store.add(collection_name, ids, vectors, metadata)
|
||||
|
||||
search_results = milvus_store.search(collection_name, query=[1.0] * 8)
|
||||
assert len(search_results) > 0
|
||||
first_result = search_results[0]
|
||||
assert first_result["id"] == "doc_0"
|
||||
|
||||
search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1})
|
||||
assert len(search_results_with_filter) > 0
|
||||
assert search_results_with_filter[0]["id"] == "doc_1"
|
||||
|
||||
milvus_store.delete(collection_name, _ids=["doc_0"])
|
||||
deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1)
|
||||
assert deleted_results[0]["id"] != "doc_0"
|
||||
|
||||
milvus_store.client.drop_collection(collection_name)
|
||||
@@ -7,7 +7,7 @@ from metagpt.rag.schema import (
|
||||
ChromaIndexConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSIndexConfig,
|
||||
FAISSIndexConfig, MilvusIndexConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -20,6 +20,10 @@ class TestRAGIndexFactory:
|
||||
def faiss_config(self):
|
||||
return FAISSIndexConfig(persist_path="")
|
||||
|
||||
@pytest.fixture
|
||||
def milvus_config(self):
|
||||
return MilvusIndexConfig(uri="", collection_name="")
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_config(self):
|
||||
return ChromaIndexConfig(persist_path="", collection_name="")
|
||||
@@ -65,6 +69,16 @@ class TestRAGIndexFactory:
|
||||
):
|
||||
self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
|
||||
|
||||
def test_create_milvus_index(self, mocker, milvus_config, mock_from_vector_store, mock_embedding):
|
||||
# Mock
|
||||
mock_milvus_store = mocker.patch("metagpt.rag.factories.index.MilvusVectorStore")
|
||||
|
||||
# Exec
|
||||
self.index_factory.get_index(milvus_config, embed_model=mock_embedding)
|
||||
|
||||
# Assert
|
||||
mock_milvus_store.assert_called_once()
|
||||
|
||||
def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
|
||||
# Mock
|
||||
mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")
|
||||
|
||||
@@ -5,6 +5,7 @@ from llama_index.core.embeddings import MockEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
from llama_index.vector_stores.milvus import MilvusVectorStore
|
||||
|
||||
from metagpt.rag.factories.retriever import RetrieverFactory
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
@@ -12,12 +13,14 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSRetrieverConfig,
|
||||
MilvusRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -41,6 +44,10 @@ class TestRetrieverFactory:
|
||||
def mock_chroma_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ChromaVectorStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_milvus_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=MilvusVectorStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_es_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ElasticsearchStore)
|
||||
@@ -91,6 +98,14 @@ class TestRetrieverFactory:
|
||||
|
||||
assert isinstance(retriever, ChromaRetriever)
|
||||
|
||||
def test_get_retriever_with_milvus_config(self, mocker, mock_milvus_vector_store, mock_embedding):
|
||||
mock_config = MilvusRetrieverConfig(uri="/path/to/milvus.db", collection_name="test_collection")
|
||||
mocker.patch("metagpt.rag.factories.retriever.MilvusVectorStore", return_value=mock_milvus_vector_store)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, MilvusRetriever)
|
||||
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
|
||||
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
|
||||
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
|
||||
|
||||
Reference in New Issue
Block a user