305 lines
9.3 KiB
Python
305 lines
9.3 KiB
Python
# Text Chunking Utilities with heading-aware splitting
|
|
import re
import textwrap
from typing import List
|
|
|
|
|
|
def estimate_tokens(text: str) -> int:
    """
    Approximate the token count of a piece of text.

    The estimate assumes four characters per token and rounds down, so
    texts shorter than four characters count as zero tokens.

    Args:
        text: The text to measure

    Returns:
        Number of whole four-character groups in ``text``
    """
    chars_per_token = 4
    whole_tokens, _leftover_chars = divmod(len(text), chars_per_token)
    return whole_tokens
|
|
|
|
|
|
def _split_at_headings(text: str) -> List[tuple]:
|
|
"""
|
|
Split text at markdown headings while preserving heading content.
|
|
|
|
Args:
|
|
text: The full text
|
|
|
|
Returns:
|
|
List of (heading_text, remaining_text) tuples or [(text,) if no headings]
|
|
"""
|
|
# Match markdown headings (##, ###, ####, etc.)
|
|
pattern = r'(#{1,6})\s+(.+?)(?=\n#{1,6}|\Z)'
|
|
|
|
parts = []
|
|
remaining = text
|
|
|
|
while True:
|
|
match = re.search(pattern, remaining, re.MULTILINE)
|
|
if not match:
|
|
break
|
|
|
|
heading_start = match.start()
|
|
heading_content = match.group(0).strip()
|
|
|
|
# Insert the heading chunk
|
|
parts.append((heading_content, None))
|
|
remaining = remaining[match.end():]
|
|
|
|
if remaining and not parts:
|
|
return [(text,)]
|
|
|
|
if remaining:
|
|
# Add final non-heading section
|
|
last_h_start = sum(len(h) for _, h in parts)
|
|
parts.append((remaining[last_h_start:], None))
|
|
|
|
if not parts and text:
|
|
parts = [(text,)]
|
|
|
|
return parts
|
|
|
|
|
|
def _split_at_paragraphs(text: str, max_tokens: int) -> List[str]:
    """
    Pack paragraphs into chunks that fit within a token budget.

    Paragraphs are separated by blank lines. Consecutive paragraphs are
    joined with a blank line for as long as the joined text stays within
    ``max_tokens`` (as measured by ``estimate_tokens``). A paragraph that
    is too large on its own is handed to ``_split_at_sentences``.

    Args:
        text: The text to split
        max_tokens: Maximum tokens per chunk

    Returns:
        List of non-empty chunks in document order; empty list for empty
        or whitespace-only text
    """
    if not text or not text.strip():
        return []

    paragraphs = [p for p in re.split(r'\n\s*\n', text.strip()) if p.strip()]

    chunks = []
    current = ""

    for para in paragraphs:
        # Measure the real joined text so the separator is accounted for.
        candidate = f"{current}\n\n{para}" if current else para
        if estimate_tokens(candidate) <= max_tokens:
            current = candidate
            continue

        # The paragraph does not fit: emit what we have and start fresh, so
        # the emitted chunk can never be extended or emitted a second time.
        if current:
            chunks.append(current)
            current = ""

        if estimate_tokens(para) <= max_tokens:
            current = para
            continue

        # Oversized paragraph: fall back to sentence-level splitting. All
        # pieces but the last are final; the last stays open so following
        # short paragraphs can still be packed after it.
        pieces = _split_at_sentences(para, max_tokens)
        chunks.extend(pieces[:-1])
        current = pieces[-1] if pieces else ""

    if current:
        chunks.append(current)

    return chunks
|
|
|
|
|
|
def _split_at_sentences(text: str, max_tokens: int) -> List[str]:
    """
    Pack sentences into chunks that fit within a token budget.

    Sentences end at ``.``, ``!`` or ``?`` followed by whitespace; the
    punctuation stays attached to its sentence. Consecutive sentences are
    joined with a single space for as long as the joined text stays within
    ``max_tokens`` (as measured by ``estimate_tokens``). A sentence that is
    too large on its own is wrapped at word boundaries into pieces of at
    most ``max_tokens * 4`` characters; only a single word longer than
    that is cut mid-word.

    Args:
        text: The text to split
        max_tokens: Maximum tokens per chunk (values below 1 are treated
            as 1 so the function always makes progress)

    Returns:
        List of non-empty chunks in document order; empty list for empty
        or whitespace-only text
    """
    if not text or not text.strip():
        return []

    budget = max(max_tokens, 1)

    # Lookbehind keeps the terminating punctuation on the sentence itself.
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]

    chunks = []
    current = ""

    for sentence in sentences:
        candidate = f"{current} {sentence}" if current else sentence
        if estimate_tokens(candidate) <= budget:
            current = candidate
            continue

        # Does not fit: emit the open chunk and reset it so it is never
        # emitted twice.
        if current:
            chunks.append(current)
            current = ""

        if estimate_tokens(sentence) <= budget:
            current = sentence
            continue

        # Oversized sentence: wrap on word boundaries using a character
        # width (4 characters per token, matching estimate_tokens). Note
        # that textwrap normalizes tabs/newlines inside the sentence to
        # spaces.
        pieces = textwrap.wrap(
            sentence,
            width=budget * 4,
            break_long_words=True,
            break_on_hyphens=False,
        )
        chunks.extend(pieces[:-1])
        # Keep the last piece open so the next sentence may join it.
        current = pieces[-1] if pieces else ""

    if current:
        chunks.append(current)

    return chunks
|
|
|
|
|
|
# Split points for chunk_text, most preferred first: before a markdown
# heading, at a blank line, after sentence punctuation, at any whitespace.
# Every pattern starts on whitespace, so a split never cuts through a word.
_CHUNK_SPLIT_PATTERNS = (
    re.compile(r"\n#{1,6}\s+"),
    re.compile(r"\n\s*\n"),
    re.compile(r"(?<=[.!?])\s+"),
    re.compile(r"\s+"),
)
_CHUNK_PARAGRAPH_BREAK = re.compile(r"\n\s*\n")
_CHUNK_WHITESPACE = re.compile(r"\s")
# Budgets at or below this many tokens use the one-paragraph-per-chunk path.
_CHUNK_SMALL_BUDGET_TOKENS = 20


def chunk_text(text: str, max_tokens: int = 500, overlap_tokens: int = 80) -> List[str]:
    """
    Chunk text using heading, paragraph, and sentence boundaries.

    The text is walked with a sliding window of ``max_tokens * 4``
    characters (1 token is approximated as 4 characters). Each window is
    cut at the last suitable boundary in its second half, preferring, in
    order: the start of a markdown heading line, a blank line, the end of
    a sentence, any whitespace. If the second half of the window contains
    no whitespace at all, the window is cut at its hard character limit.

    Consecutive chunks overlap by up to ``overlap_tokens * 4`` characters
    (capped at half the window). The overlap start is moved forward to the
    next word boundary so a chunk does not begin with a word fragment; if
    the overlap region contains no whitespace the raw character offset is
    used.

    Special case: when ``max_tokens`` is 20 or less and the text consists
    of several blank-line-separated paragraphs that each fit within the
    character budget, the paragraphs themselves are returned as the
    chunks, without overlap.

    Args:
        text: The full text to chunk
        max_tokens: Maximum tokens per chunk (default 500)
        overlap_tokens: Number of overlapping tokens between chunks
            (default 80); negative values are treated as 0

    Returns:
        List of non-empty, stripped chunk strings, none longer than
        ``max_tokens * 4`` characters. Empty list for empty or
        whitespace-only text.

    Raises:
        TypeError: If ``text`` is None
        ValueError: If ``max_tokens`` is not greater than 0 (only checked
            for non-empty text)
    """
    if text is None:
        raise TypeError("text must be a string")

    if not text:
        return []

    if max_tokens <= 0:
        raise ValueError("max_tokens must be greater than 0")

    clean_text = text.strip()
    if not clean_text:
        return []

    max_chars = max(1, max_tokens * 4)
    overlap_chars = min(max(overlap_tokens, 0) * 4, max_chars // 2)

    paragraphs = [
        p.strip() for p in _CHUNK_PARAGRAPH_BREAK.split(clean_text) if p.strip()
    ]
    # Compare character lengths so this path honours the same hard budget
    # as the windowed path below.
    if (
        len(paragraphs) > 1
        and max_tokens <= _CHUNK_SMALL_BUDGET_TOKENS
        and all(len(p) <= max_chars for p in paragraphs)
    ):
        return paragraphs

    chunks = []
    text_len = len(clean_text)
    # Non-final windows are always exactly max_chars long, so the earliest
    # allowed split position is the same for every window.
    min_split = max(1, max_chars // 2)
    start = 0

    while start < text_len:
        if text_len - start <= max_chars:
            # Everything that is left fits in one chunk.
            tail = clean_text[start:].strip()
            if tail:
                chunks.append(tail)
            break

        window = clean_text[start:start + max_chars]

        split_at = max_chars  # hard cut when no boundary is available
        for pattern in _CHUNK_SPLIT_PATTERNS:
            candidates = [
                m.start() for m in pattern.finditer(window) if m.start() >= min_split
            ]
            if candidates:
                # finditer yields matches left to right; take the last one
                # to make the chunk as large as possible.
                split_at = candidates[-1]
                break

        end = start + split_at
        chunk = clean_text[start:end].strip()
        if chunk:
            chunks.append(chunk)

        next_start = end
        if overlap_chars:
            next_start = end - overlap_chars
            starts_mid_word = (
                next_start > 0
                and not clean_text[next_start - 1].isspace()
                and not clean_text[next_start].isspace()
            )
            if starts_mid_word:
                # Skip the partial word: resume at the next whitespace
                # inside the overlap region, if there is one.
                boundary = _CHUNK_WHITESPACE.search(clean_text, next_start, end)
                if boundary:
                    next_start = boundary.start()
            if next_start <= start:
                # Overlap would not advance; drop it to guarantee progress.
                next_start = end

        start = next_start

    return chunks
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke tests, run only when this module is executed as a script.

    # estimate_tokens: 400 characters should count as 100 tokens.
    four_hundred_chars = "a" * 400
    assert estimate_tokens(four_hundred_chars) == 100, f"Expected 100 tokens for 400 chars, got {estimate_tokens(four_hundred_chars)}"
    print(f"estimate_tokens test passed: 400 chars -> {estimate_tokens(four_hundred_chars)} tokens")

    # Empty input produces no chunks.
    assert chunk_text("") == [], "Empty text should return empty list"
    print("chunk_text empty test passed")

    # Text well below the limit comes back unchanged as one chunk.
    short_text = "This is a very short text that should be returned as a single chunk."
    result = chunk_text(short_text)
    assert len(result) == 1, f"Short text should be one chunk, got {len(result)}"
    assert result[0] == short_text, "Content should match for small text"
    print("chunk_text single chunk test passed")

    # Markdown document with headings; its blocks are separated by blank lines.
    sample_markdown = "\n\n".join([
        "# Introduction",
        "This is the introduction section.",
        "## Background",
        "Background information goes here to make this longer and test chunking.",
        "This paragraph has more content about the background topic.",
        "### Details",
        "Specific details about the background are provided in this subsection.",
        "More details follow here to ensure we have enough text to properly test heading preservation.",
        "## Conclusion",
        "The conclusion wraps up everything nicely.",
    ])
    result = chunk_text(sample_markdown, max_tokens=50)

    # Report the chunks that begin with a heading marker.
    starts_with_heading = []
    for piece in result:
        if piece.strip().startswith('#'):
            starts_with_heading.append(piece)
    print(f"\nFound {len(starts_with_heading)} heading chunks:")
    for piece in starts_with_heading:
        print(f" - {piece.strip()}")

    assert len(result) > 1, f"Should have multiple chunks, got {len(result)}"

    # Every chunk must stay within the limit plus a 20-token tolerance.
    within_limit = True
    for piece in result:
        if estimate_tokens(piece) > 50 + 20:
            within_limit = False
            break
    assert within_limit, "Some chunks exceed token limit significantly"
    print("All chunks respect token limits")

    print("\nAll tests passed!")
|