Skip to content

Newest Embedding model #8

@chengzicong20040913

Description

@chengzicong20040913
class EmbeddingModel:
    """Wrapper around an embedding backend.

    When ``model`` is exactly ``'OpenAI'``, calls are delegated to an
    ``OpenAIEmbeddings`` client (presumably LangChain's — its import is not
    shown here). Any other value is treated as a model name for an
    OpenAI-compatible ``/embeddings`` endpoint. That endpoint is reached via
    the ``OpenAI`` SDK client for the sync methods and via raw ``aiohttp``
    requests for ``aembed_query``.
    """

    # Number of texts sent per request by ``embed_documents``.
    _BATCH_SIZE = 10
    # ``embed_documents`` truncates each text to this many characters.
    _MAX_CHARS = 16384
    # Embedding dimensionality requested from the OpenAI-compatible endpoint.
    _DIMENSIONS = 1024

    def __init__(self, model: str, openai_api_key: str = None, openai_api_base: str = None):
        """Create the underlying client.

        Args:
            model: ``'OpenAI'`` to use ``OpenAIEmbeddings``; otherwise the model
                name sent to the OpenAI-compatible endpoint.
            openai_api_key: API key passed to the client / used as Bearer token.
            openai_api_base: Base URL of the API (``/embeddings`` is appended
                for async requests).
        """
        self.model_name = model
        self.key = openai_api_key
        self.api_base = openai_api_base
        self.client = None
        if self.model_name == 'OpenAI':
            # NOTE(review): this passes the literal string 'OpenAI' as the model
            # name; confirm the backend accepts it, or pass a real model id.
            self.client = OpenAIEmbeddings(
                model=self.model_name,
                openai_api_key=self.key,
                openai_api_base=self.api_base
            )
        else:
            self.client = OpenAI(
                api_key=self.key,
                base_url=self.api_base,
            )

    @staticmethod
    def _sorted_embeddings(output: dict) -> List[List[float]]:
        """Return the vectors from a parsed ``/embeddings`` response, ordered by ``index``."""
        # Items may come back out of order; ``index`` maps each item to its input.
        ordered = sorted(output["data"], key=lambda item: int(item["index"]))
        return [item["embedding"] for item in ordered]

    def _create_embeddings(self, inputs: List[str]) -> List[List[float]]:
        """Embed ``inputs`` in a single sync request to the OpenAI-compatible endpoint."""
        completion = self.client.embeddings.create(
            model=self.model_name,
            input=inputs,
            dimensions=self._DIMENSIONS,
            encoding_format="float"
        )
        # Round-trip through JSON to get plain dicts/lists of floats.
        return self._sorted_embeddings(json.loads(completion.model_dump_json()))

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single ``text``."""
        if self.model_name == 'OpenAI':
            return self.client.embed_query(text)
        return self._create_embeddings([text])[0]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``texts``, in input order.

        For the OpenAI-compatible endpoint, texts are sent in batches of
        ``_BATCH_SIZE``. Each text is truncated to ``_MAX_CHARS`` characters.
        An empty ``texts`` yields an empty list.
        """
        if self.model_name == 'OpenAI':
            return self.client.embed_documents(texts)
        all_embeddings: List[List[float]] = []
        for start in range(0, len(texts), self._BATCH_SIZE):
            batch = [t[:self._MAX_CHARS] for t in texts[start:start + self._BATCH_SIZE]]
            all_embeddings.extend(self._create_embeddings(batch))
        return all_embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronously return the embedding vector for a single ``text``.

        Raises:
            aiohttp.ClientResponseError: if the endpoint returns an HTTP error
                status (non-``'OpenAI'`` backend only).
        """
        if self.model_name == 'OpenAI':
            # Assumes the OpenAIEmbeddings client provides an async variant.
            return await self.client.aembed_query(text)
        async with aiohttp.ClientSession() as session:
            headers = {
                "Authorization": f"Bearer {self.key}",
                "Content-Type": "application/json"
            }
            payload = {
                "model": self.model_name,
                "input": [text],
                "dimensions": self._DIMENSIONS,
                "encoding_format": "float"
            }

            async with session.post(
                f"{self.api_base}/embeddings",
                headers=headers,
                json=payload
            ) as response:
                # Surface HTTP errors directly instead of a KeyError on "data".
                response.raise_for_status()
                output = await response.json()
                return self._sorted_embeddings(output)[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously return one embedding vector per entry of ``texts``."""
        if self.model_name == 'OpenAI':
            return await self.client.aembed_documents(texts)
        # One concurrent request per text; asyncio.gather preserves input order.
        tasks = [self.aembed_query(text) for text in texts]
        return await asyncio.gather(*tasks)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions