forked from HSE_team/BetterCallPraskovia
Забыл батчить, теперь ок
This commit is contained in:
parent
683f779c31
commit
42fcc0eb16
@ -29,7 +29,7 @@ class RAGService:
|
||||
self.splitter = splitter or TextSplitter()
|
||||
|
||||
async def index_document(self, document: Document) -> list[DocumentChunk]:
|
||||
chunks_text = self.splitter.split(document.content)
|
||||
chunks_text = self.splitter.split(document.content)
|
||||
chunks: list[DocumentChunk] = []
|
||||
for idx, text in enumerate(chunks_text):
|
||||
chunks.append(
|
||||
@ -42,9 +42,18 @@ class RAGService:
|
||||
)
|
||||
)
|
||||
|
||||
embeddings = self.embedding_service.embed_texts([c.content for c in chunks])
|
||||
EMBEDDING_BATCH_SIZE = 50
|
||||
all_embeddings: list[list[float]] = []
|
||||
|
||||
for i in range(0, len(chunks), EMBEDDING_BATCH_SIZE):
|
||||
batch_chunks = chunks[i:i + EMBEDDING_BATCH_SIZE]
|
||||
batch_texts = [c.content for c in batch_chunks]
|
||||
batch_embeddings = self.embedding_service.embed_texts(batch_texts)
|
||||
all_embeddings.extend(batch_embeddings)
|
||||
|
||||
print(f"Created {len(all_embeddings)} embeddings, upserting to Qdrant...")
|
||||
await self.vector_repository.upsert_chunks(
|
||||
chunks, embeddings, model_version=self.embedding_service.model_version()
|
||||
chunks, all_embeddings, model_version=self.embedding_service.model_version()
|
||||
)
|
||||
return chunks
|
||||
|
||||
|
||||
@ -39,5 +39,10 @@ class TextSplitter:
|
||||
|
||||
def _split_sentences(self, text: str) -> Iterable[str]:
|
||||
parts = re.split(r"(?<=[\.\?\!])\s+", text)
|
||||
if len(parts) == 1 and len(text) > self.chunk_size * 2:
|
||||
chunk_text = []
|
||||
for i in range(0, len(text), self.chunk_size):
|
||||
chunk_text.append(text[i:i + self.chunk_size])
|
||||
return chunk_text
|
||||
return [p.strip() for p in parts if p.strip()]
|
||||
|
||||
|
||||
@ -36,6 +36,8 @@ class QdrantVectorRepository(IVectorRepository):
|
||||
embeddings: Sequence[list[float]],
|
||||
model_version: str,
|
||||
) -> None:
|
||||
BATCH_SIZE = 100
|
||||
|
||||
points = []
|
||||
for chunk, vector in zip(chunks, embeddings):
|
||||
points.append(
|
||||
@ -52,7 +54,13 @@ class QdrantVectorRepository(IVectorRepository):
|
||||
},
|
||||
)
|
||||
)
|
||||
self.client.upsert(collection_name=self.collection_name, points=points)
|
||||
|
||||
if len(points) >= BATCH_SIZE:
|
||||
self.client.upsert(collection_name=self.collection_name, points=points)
|
||||
points = []
|
||||
|
||||
if points:
|
||||
self.client.upsert(collection_name=self.collection_name, points=points)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
|
||||
@ -84,7 +84,6 @@ async def process_premium_question(message: Message, user: User, question_text:
|
||||
try:
|
||||
from urllib.parse import unquote
|
||||
decoded = unquote(title)
|
||||
# Если декодирование изменило строку или исходная содержит %XX
|
||||
if decoded != title or '%' in title:
|
||||
title = decoded
|
||||
except:
|
||||
@ -152,7 +151,6 @@ async def process_free_question(message: Message, user: User, question_text: str
|
||||
try:
|
||||
from urllib.parse import unquote
|
||||
decoded = unquote(title)
|
||||
# Если декодирование изменило строку или исходная содержит %XX
|
||||
if decoded != title or '%' in title:
|
||||
title = decoded
|
||||
except:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user