Забыл батчить, теперь ок

This commit is contained in:
Arxip222 2025-12-24 16:17:50 +03:00
parent 683f779c31
commit 42fcc0eb16
4 changed files with 26 additions and 6 deletions

View File

@ -29,7 +29,7 @@ class RAGService:
self.splitter = splitter or TextSplitter()
async def index_document(self, document: Document) -> list[DocumentChunk]:
chunks_text = self.splitter.split(document.content)
chunks_text = self.splitter.split(document.content)
chunks: list[DocumentChunk] = []
for idx, text in enumerate(chunks_text):
chunks.append(
@ -42,9 +42,18 @@ class RAGService:
)
)
embeddings = self.embedding_service.embed_texts([c.content for c in chunks])
EMBEDDING_BATCH_SIZE = 50
all_embeddings: list[list[float]] = []
for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
batch_chunks = chunks[i:i + EMBEDDING_BATCH_SIZE]
batch_texts = [c.content for c in batch_chunks]
batch_embeddings = self.embedding_service.embed_texts(batch_texts)
all_embeddings.extend(batch_embeddings)
print(f"Created {len(all_embeddings)} embeddings, upserting to Qdrant...")
await self.vector_repository.upsert_chunks(
chunks, embeddings, model_version=self.embedding_service.model_version()
chunks, all_embeddings, model_version=self.embedding_service.model_version()
)
return chunks

View File

@ -39,5 +39,10 @@ class TextSplitter:
def _split_sentences(self, text: str) -> Iterable[str]:
parts = re.split(r"(?<=[\.\?\!])\s+", text)
if len(parts) == 1 and len(text) > self.chunk_size * 2:
chunk_text = []
for i in range(0, len(text), self.chunk_size):
chunk_text.append(text[i:i + self.chunk_size])
return chunk_text
return [p.strip() for p in parts if p.strip()]

View File

@ -36,6 +36,8 @@ class QdrantVectorRepository(IVectorRepository):
embeddings: Sequence[list[float]],
model_version: str,
) -> None:
BATCH_SIZE = 100
points = []
for chunk, vector in zip(chunks, embeddings):
points.append(
@ -52,7 +54,13 @@ class QdrantVectorRepository(IVectorRepository):
},
)
)
self.client.upsert(collection_name=self.collection_name, points=points)
if len(points) >= BATCH_SIZE:
self.client.upsert(collection_name=self.collection_name, points=points)
points = []
if points:
self.client.upsert(collection_name=self.collection_name, points=points)
async def search(
self,

View File

@ -84,7 +84,6 @@ async def process_premium_question(message: Message, user: User, question_text:
try:
from urllib.parse import unquote
decoded = unquote(title)
# Если декодирование изменило строку или исходная содержит %XX
if decoded != title or '%' in title:
title = decoded
except:
@ -152,7 +151,6 @@ async def process_free_question(message: Message, user: User, question_text: str
try:
from urllib.parse import unquote
decoded = unquote(title)
# Если декодирование изменило строку или исходная содержит %XX
if decoded != title or '%' in title:
title = decoded
except: