A simple "needle in a haystack" analysis to test the in-context retrieval ability of long-context LLMs.
The Test
- Place a random fact or statement (the 'needle') in the middle of a long context window (the 'haystack')
- Ask the model to retrieve this statement
- Iterate over various document depths (where the needle is placed) and context lengths to measure performance
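The iteration above is just a grid over (context length, depth) pairs. A minimal sketch of how that grid is built, assuming linear spacing (the repo uses numpy.linspace; this pure-Python helper is a stand-in that rounds to integers):

```python
def linspace_int(lo: int, hi: int, num: int) -> list[int]:
    """Evenly spaced integers from lo to hi inclusive (endpoint=True)."""
    if num == 1:
        return [lo]
    step = (hi - lo) / (num - 1)
    return [round(lo + i * step) for i in range(num)]

context_lengths = linspace_int(1000, 16000, 4)   # [1000, 6000, 11000, 16000]
depth_percents = linspace_int(0, 100, 5)         # [0, 25, 50, 75, 100]
grid = [(c, d) for c in context_lengths for d in depth_percents]
print(len(grid))  # 20 (length, depth) test cases
```

Each pair in the grid becomes one model call: generate a context of that length, bury the needle at that depth, and score the retrieval.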
LLMNeedleHaystackTester parameters:
- model_to_test - The model to run the needle-in-a-haystack test on. Default is None.
- evaluator - An evaluator to evaluate the model's response. Default is None.
- needle - The statement or fact which will be placed in your context ('haystack')
- haystack_dir - The directory which contains the text files to load as background context. Only text files are supported.
- retrieval_question - The question with which to retrieve your needle in the background context
- results_version - You may want to run your test multiple times for the same combination of length/depth; change the version number if so
- num_concurrent_requests - Default: 1. Set higher if you'd like to run more requests in parallel. Keep in mind rate limits.
- save_results - Whether or not you'd like to save your results to file. They will be temporarily saved in the object regardless. True/False. If save_results = True, this script will populate a results/ directory with evaluation information. Due to potential concurrent requests, each new test will be saved as a new file.
- save_contexts - Whether or not you'd like to save your contexts to file. Warning: these will get very long. True/False
- final_context_length_buffer - The amount of context to take off each input to account for system messages and output tokens. This could be more intelligent, but uses a static value for now. Default: 200 tokens.
- context_lengths_min - The starting point of your context lengths list to iterate
- context_lengths_max - The ending point of your context lengths list to iterate
- context_lengths_num_intervals - The number of intervals between your min/max to iterate through
- context_lengths - A custom set of context lengths. This will override the values set for context_lengths_min, max, and intervals if set
- document_depth_percent_min - The starting point of your document depths. Should be an int > 0
- document_depth_percent_max - The ending point of your document depths. Should be an int < 100
- document_depth_percent_intervals - The number of iterations to do between your min/max points
- document_depth_percents - A custom set of document depths. This will override the values set for document_depth_percent_min, max, and intervals if set
- document_depth_percent_interval_type - Determines the distribution of depths to iterate over. 'linear' or 'sigmoid'
- seconds_to_sleep_between_completions - Default: None; set a number of seconds if you'd like to slow down your requests
- print_ongoing_status - Default: True; whether or not to print the status of tests as they complete
LLMMultiNeedleHaystackTester parameters:
- multi_needle - True or False, whether to run multi-needle
- needles - List of needles to insert into the context
Other parameters:
- model_name - The name of the model you'd like to use. Should match the exact value which needs to be passed to the API. Ex: for OpenAI inference and evaluator models it would be gpt-3.5-turbo-0125.
Multi Needle Evaluator
To enable multi-needle insertion into our context, use --multi_needle True.
This inserts the first needle at the specified depth_percent, then evenly distributes subsequent needles through the remaining context after this depth.
For even spacing, it calculates the depth_percent_interval as:
depth_percent_interval = (100 - depth_percent) / len(self.needles)
So the first needle is placed at a depth percent of depth_percent, the second at depth_percent + depth_percent_interval, the third at depth_percent + 2 * depth_percent_interval, and so on.
The following example shows the depth percents for 10 needles and a depth_percent of 40%:
depth_percent_interval = (100 - 40) / 10 = 6
Needle 1: 40
Needle 2: 40 + 6 = 46
Needle 3: 40 + 2 * 6 = 52
Needle 4: 40 + 3 * 6 = 58
Needle 5: 40 + 4 * 6 = 64
Needle 6: 40 + 5 * 6 = 70
Needle 7: 40 + 6 * 6 = 76
Needle 8: 40 + 7 * 6 = 82
Needle 9: 40 + 8 * 6 = 88
Needle 10: 40 + 9 * 6 = 94
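The table above can be reproduced with a few lines of Python; multi_needle_depths is a hypothetical helper mirroring the interval formula, not part of the repo:

```python
def multi_needle_depths(depth_percent: float, num_needles: int) -> list[float]:
    """Return the insertion depth (in percent) for each needle."""
    interval = (100 - depth_percent) / num_needles
    return [depth_percent + i * interval for i in range(num_needles)]

print(multi_needle_depths(40, 10))
# → [40.0, 46.0, 52.0, 58.0, 64.0, 70.0, 76.0, 82.0, 88.0, 94.0]
```

Note that because the interval divides the *remaining* (100 - depth_percent) span by the needle count, the last needle lands short of 100% (here at 94%).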
run.py
@dataclass
class CommandArgs():
    ...
    needle: Optional[str] = "\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n"
    haystack_dir: Optional[str] = "PaulGrahamEssays"
    retrieval_question: Optional[str] = "What is the best thing to do in San Francisco?"
    results_version: Optional[int] = 1
    context_lengths_min: Optional[int] = 1000
    context_lengths_max: Optional[int] = 16000
    context_lengths_num_intervals: Optional[int] = 35
    context_lengths: Optional[list[int]] = None
    document_depth_percent_min: Optional[int] = 0
    document_depth_percent_max: Optional[int] = 100
    document_depth_percent_intervals: Optional[int] = 35
    document_depth_percents: Optional[list[int]] = None
    document_depth_percent_interval_type: Optional[str] = "linear"
    ...
    save_results: Optional[bool] = True
    save_contexts: Optional[bool] = True
    final_context_length_buffer: Optional[int] = 200
    ...
    # Multi-needle parameters
    multi_needle: Optional[bool] = False
    needles: list[str] = field(default_factory=lambda: [
        " Figs are one of the secret ingredients needed to build the perfect pizza. ",
        " Prosciutto is one of the secret ingredients needed to build the perfect pizza. ",
        " Goat cheese is one of the secret ingredients needed to build the perfect pizza. "
    ])
The default needle is:
"\nThe best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.\n"
The haystack is stored in the PaulGrahamEssays directory.
The retrieval query is:
What is the best thing to do in San Francisco?
The multi-needle list is:
[ " Figs are one of the secret ingredients needed to build the perfect pizza. ", " Prosciutto is one of the secret ingredients needed to build the perfect pizza. ", " Goat cheese is one of the secret ingredients needed to build the perfect pizza. " ]
The entry point is:
def main():
    args = CLI(CommandArgs, as_positional=False)
    args.model_to_test = get_model_to_test(args)
    args.evaluator = get_evaluator(args)

    if args.multi_needle == True:
        print("Testing multi-needle")
        tester = LLMMultiNeedleHaystackTester(**args.__dict__)
    else:
        print("Testing single-needle")
        tester = LLMNeedleHaystackTester(**args.__dict__)
    tester.start_test()
providers
Implements the calls to the different LLMs.
from abc import ABC, abstractmethod
from typing import Optional

class ModelProvider(ABC):
    @abstractmethod
    async def evaluate_model(self, prompt: str) -> str: ...

    @abstractmethod
    def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]: ...

    @abstractmethod
    def encode_text_to_tokens(self, text: str) -> list[int]: ...

    @abstractmethod
    def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str: ...
import os
from operator import itemgetter
from typing import Optional

from openai import AsyncOpenAI
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
import tiktoken

from .model import ModelProvider

class OpenAI(ModelProvider):
    """
    A wrapper class for interacting with OpenAI's API, providing methods to encode text, generate prompts,
    evaluate models, and create LangChain runnables for language model interactions.

    Attributes:
        model_name (str): The name of the OpenAI model to use for evaluations and interactions.
        model (AsyncOpenAI): An instance of the AsyncOpenAI client for asynchronous API calls.
        tokenizer: A tokenizer instance for encoding and decoding text to and from token representations.
    """
    DEFAULT_MODEL_KWARGS: dict = dict(max_tokens=300,
                                      temperature=0)

    def __init__(self,
                 model_name: str = "gpt-3.5-turbo-0125",
                 model_kwargs: dict = DEFAULT_MODEL_KWARGS):
        """
        Initializes the OpenAI model provider with a specific model.

        Args:
            model_name (str): The name of the OpenAI model to use. Defaults to 'gpt-3.5-turbo-0125'.
            model_kwargs (dict): Model configuration. Defaults to {max_tokens: 300, temperature: 0}.

        Raises:
            ValueError: If NIAH_MODEL_API_KEY is not found in the environment.
        """
        api_key = os.getenv('NIAH_MODEL_API_KEY')
        if not api_key:
            raise ValueError("NIAH_MODEL_API_KEY must be in env.")

        self.model_name = model_name
        self.model_kwargs = model_kwargs
        self.api_key = api_key
        self.model = AsyncOpenAI(api_key=self.api_key)
        self.tokenizer = tiktoken.encoding_for_model(self.model_name)

    async def evaluate_model(self, prompt: str) -> str:
        """
        Evaluates a given prompt using the OpenAI model and retrieves the model's response.

        Args:
            prompt (str): The prompt to send to the model.

        Returns:
            str: The content of the model's response to the prompt.
        """
        response = await self.model.chat.completions.create(
            model=self.model_name,
            messages=prompt,
            **self.model_kwargs
        )
        return response.choices[0].message.content

    def generate_prompt(self, context: str, retrieval_question: str) -> str | list[dict[str, str]]:
        """
        Generates a structured prompt for querying the model, based on a given context and retrieval question.

        Args:
            context (str): The context or background information relevant to the question.
            retrieval_question (str): The specific question to be answered by the model.

        Returns:
            list[dict[str, str]]: A list of dictionaries representing the structured prompt, including roles and content for system and user messages.
        """
        return [{
            "role": "system",
            "content": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"
        },
        {
            "role": "user",
            "content": context
        },
        {
            "role": "user",
            "content": f"{retrieval_question} Don't give information outside the document or repeat your findings"
        }]

    def encode_text_to_tokens(self, text: str) -> list[int]:
        """
        Encodes a given text string to a sequence of tokens using the model's tokenizer.

        Args:
            text (str): The text to encode.

        Returns:
            list[int]: A list of token IDs representing the encoded text.
        """
        return self.tokenizer.encode(text)

    def decode_tokens(self, tokens: list[int], context_length: Optional[int] = None) -> str:
        """
        Decodes a sequence of tokens back into a text string using the model's tokenizer.

        Args:
            tokens (list[int]): The sequence of token IDs to decode.
            context_length (Optional[int], optional): An optional length specifying the number of tokens to decode. If not provided, decodes all tokens.

        Returns:
            str: The decoded text string.
        """
        return self.tokenizer.decode(tokens[:context_length])

    def get_langchain_runnable(self, context: str) -> str:
        """
        Creates a LangChain runnable that constructs a prompt based on a given context and a question,
        queries the OpenAI model, and returns the model's response. This method leverages the LangChain
        library to build a sequence of operations: extracting input variables, generating a prompt,
        querying the model, and processing the response.

        Args:
            context (str): The context or background information relevant to the user's question.
                This context is provided to the model to aid in generating relevant and accurate responses.

        Returns:
            str: A LangChain runnable object that can be executed to obtain the model's response to a
                dynamically provided question. The runnable encapsulates the entire process from prompt
                generation to response retrieval.

        Example:
            To use the runnable:
                - Define the context and question.
                - Execute the runnable with these parameters to get the model's response.
        """
        template = """You are a helpful AI bot that answers questions for a user. Keep your response short and direct" \n
        \n ------- \n
        {context}
        \n ------- \n
        Here is the user question: \n --- --- --- \n {question} \n Don't give information outside the document or repeat your findings."""
        prompt = PromptTemplate(
            template=template,
            input_variables=["context", "question"],
        )
        # Create a LangChain runnable
        model = ChatOpenAI(temperature=0, model=self.model_name)
        chain = ({"context": lambda x: context,
                  "question": itemgetter("question")}
                 | prompt
                 | model)
        return chain
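The three-message structure that generate_prompt produces can be built without any API client. A standalone sketch for illustration (build_prompt is a hypothetical name; the strings match the provider above):

```python
def build_prompt(context: str, retrieval_question: str) -> list[dict[str, str]]:
    """Mirror of OpenAI.generate_prompt: system instructions, haystack, then question."""
    return [
        {"role": "system",
         "content": "You are a helpful AI bot that answers questions for a user. Keep your response short and direct"},
        {"role": "user", "content": context},
        {"role": "user",
         "content": f"{retrieval_question} Don't give information outside the document or repeat your findings"},
    ]

msgs = build_prompt("(haystack text...)", "What is the best thing to do in San Francisco?")
print([m["role"] for m in msgs])  # ['system', 'user', 'user']
```

Putting the haystack and the question in separate user messages keeps the (potentially huge) context distinct from the retrieval instruction.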
evaluators
Evaluates whether the needle was correctly retrieved. An LLM is used as the judge, with the following criteria:
CRITERIA = {"accuracy": """
Score 1: The answer is completely unrelated to the reference.
Score 3: The answer has minor relevance but does not align with the reference.
Score 5: The answer has moderate relevance but contains inaccuracies.
Score 7: The answer aligns with the reference but has minor omissions.
Score 10: The answer is completely accurate and aligns perfectly with the reference.
Only respond with a numberical score"""}
from abc import ABC, abstractmethod

class Evaluator(ABC):
    CRITERIA: dict[str, str]

    @abstractmethod
    def evaluate_response(self, response: str) -> int: ...

import os
from .evaluator import Evaluator
from langchain.evaluation import load_evaluator
from langchain_community.chat_models import ChatOpenAI

class OpenAIEvaluator(Evaluator):
    DEFAULT_MODEL_KWARGS: dict = dict(temperature=0)
    CRITERIA = {"accuracy": """
    Score 1: The answer is completely unrelated to the reference.
    Score 3: The answer has minor relevance but does not align with the reference.
    Score 5: The answer has moderate relevance but contains inaccuracies.
    Score 7: The answer aligns with the reference but has minor omissions.
    Score 10: The answer is completely accurate and aligns perfectly with the reference.
    Only respond with a numberical score"""}

    def __init__(self,
                 model_name: str = "gpt-3.5-turbo-0125",
                 model_kwargs: dict = DEFAULT_MODEL_KWARGS,
                 true_answer: str = None,
                 question_asked: str = None,):
        """
        :param model_name: The name of the model.
        :param model_kwargs: Model configuration. Default is {temperature: 0}
        :param true_answer: The true answer to the question asked.
        :param question_asked: The question asked to the model.
        """
        if (not true_answer) or (not question_asked):
            raise ValueError("true_answer and question_asked must be supplied with init.")

        self.model_name = model_name
        self.model_kwargs = model_kwargs
        self.true_answer = true_answer
        self.question_asked = question_asked

        api_key = os.getenv('NIAH_EVALUATOR_API_KEY')
        if not api_key:
            raise ValueError("NIAH_EVALUATOR_API_KEY must be in env for using openai evaluator.")
        self.api_key = api_key

        self.evaluator = ChatOpenAI(model=self.model_name,
                                    openai_api_key=self.api_key,
                                    **self.model_kwargs)

    def evaluate_response(self, response: str) -> int:
        evaluator = load_evaluator(
            "labeled_score_string",
            criteria=self.CRITERIA,
            llm=self.evaluator,
        )
        eval_result = evaluator.evaluate_strings(
            # The model's response
            prediction=response,
            # The actual answer
            reference=self.true_answer,
            # The question asked
            input=self.question_asked,
        )
        return int(eval_result['score'])
Single-needle test
The main logic lives in generate_context and run_test; the full code follows.
import asyncio
import glob
import json
import os
import time

import numpy as np

from .evaluators import Evaluator
from .providers import ModelProvider
from asyncio import Semaphore
from datetime import datetime, timezone
class LLMNeedleHaystackTester:
    """
    This class is used to test the LLM Needle Haystack.
    """
    def __init__(self,
                 model_to_test: ModelProvider = None,
                 evaluator: Evaluator = None,
                 needle = None,
                 haystack_dir = "PaulGrahamEssays",
                 retrieval_question = None,
                 results_version = 1,
                 context_lengths_min = 1000,
                 context_lengths_max = 16000,
                 context_lengths_num_intervals = 35,
                 context_lengths = None,
                 document_depth_percent_min = 0,
                 document_depth_percent_max = 100,
                 document_depth_percent_intervals = 35,
                 document_depth_percents = None,
                 document_depth_percent_interval_type = "linear",
                 num_concurrent_requests = 1,
                 save_results = True,
                 save_contexts = True,
                 final_context_length_buffer = 200,
                 seconds_to_sleep_between_completions = None,
                 print_ongoing_status = True,
                 **kwargs):
        """
        :param model_to_test: The model to test. Default is None.
        :param evaluator: An evaluator to evaluate the model's response. Default is None.
        :param needle: The needle to be found in the haystack. Default is None.
        :param haystack_dir: The directory of text files to use as background context (or a haystack) in which the needle is to be found. Default is Paul Graham Essays.
        :param retrieval_question: The question with which to prompt the model to do the retrieval.
        :param results_version: In case you would like to try the same combination of model, context length, and depth % multiple times, change the results version to something other than 1
        :param num_concurrent_requests: Due to volume, this object is set up to run concurrent requests, default = 1. Be careful of rate limits.
        :param save_results: Whether or not you would like to save your results to file. Default is True.
        :param save_contexts: Whether or not you would like to save your contexts to file. Warning: These will get long! Default is True.
        :param final_context_length_buffer: The amount of cushion you'd like to leave off the input context to allow for the output context. Default 200 tokens
        :param context_lengths_min: The minimum length of the context. Default is 1000.
        :param context_lengths_max: The maximum length of the context. Default is 16000.
        :param context_lengths_num_intervals: The number of intervals for the context length. Default is 35.
        :param context_lengths: The lengths of the context. Default is None.
        :param document_depth_percent_min: The minimum depth percent of the document. Default is 0.
        :param document_depth_percent_max: The maximum depth percent of the document. Default is 100.
        :param document_depth_percent_intervals: The number of intervals for the document depth percent. Default is 35.
        :param document_depth_percents: The depth percentages of the document. Default is None.
        :param document_depth_percent_interval_type: The type of interval for the document depth percent. Must be either 'linear' or 'sigmoid'. Default is 'linear'.
        :param seconds_to_sleep_between_completions: The number of seconds to sleep between completions. Default is None.
        :param print_ongoing_status: Whether or not to print the ongoing status. Default is True.
        :param kwargs: Additional arguments.
        """
        if not model_to_test:
            raise ValueError("A language model must be provided to test.")
        if not needle or not haystack_dir or not retrieval_question:
            raise ValueError("Needle, haystack, and retrieval_question must be provided.")

        self.needle = needle
        self.haystack_dir = haystack_dir
        self.retrieval_question = retrieval_question
        self.results_version = results_version
        self.num_concurrent_requests = num_concurrent_requests
        self.save_results = save_results
        self.final_context_length_buffer = final_context_length_buffer
        self.save_contexts = save_contexts
        self.seconds_to_sleep_between_completions = seconds_to_sleep_between_completions
        self.print_ongoing_status = print_ongoing_status
        self.testing_results = []

        if context_lengths is None:
            if context_lengths_min is None or context_lengths_max is None or context_lengths_num_intervals is None:
                raise ValueError("Either context_lengths_min, context_lengths_max, context_lengths_intervals need to be filled out OR the context_lengths_list needs to be supplied.")
            else:
                # context_lengths_num_intervals -> number of context lengths to generate
                self.context_lengths = np.round(np.linspace(context_lengths_min, context_lengths_max, num=context_lengths_num_intervals, endpoint=True)).astype(int)
        else:
            self.context_lengths = context_lengths

        if document_depth_percent_interval_type not in [None, "linear", "sigmoid"]:
            raise ValueError("document_depth_percent_interval_type must be either None, 'linear' or 'sigmoid'. If you'd like your own distribution give a list of ints in via document_depth_percent_intervals")
        if document_depth_percents is None:
            if document_depth_percent_min is None or document_depth_percent_max is None or document_depth_percent_intervals is None:
                raise ValueError("Either document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals need to be filled out OR the document_depth_percents needs to be supplied.")
            if document_depth_percent_interval_type == 'linear':
                # document_depth_percent_intervals -> number of depth percents to generate
                self.document_depth_percents = np.round(np.linspace(document_depth_percent_min, document_depth_percent_max, num=document_depth_percent_intervals, endpoint=True)).astype(int)
            elif document_depth_percent_interval_type == 'sigmoid':
                self.document_depth_percents = [self.logistic(x) for x in np.linspace(document_depth_percent_min, document_depth_percent_max, document_depth_percent_intervals)]
            else:
                raise ValueError("document_depth_percent_interval_type must be either 'sigmoid' or 'linear' if document_depth_percents is None.")
        else:
            self.document_depth_percents = document_depth_percents

        self.model_to_test = model_to_test
        self.model_name = self.model_to_test.model_name
        self.evaluation_model = evaluator
    def logistic(self, x, L=100, x0=50, k=.1):
        if x in [0, 100]:
            return x
        x = -k * (x - x0)
        return np.round(L * self.sigmoid(x), 3)

    def sigmoid(self, x):
        return 1 / (1 + np.exp(-x))
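A stdlib re-implementation of the logistic() mapping above, using the same defaults (L=100, x0=50, k=0.1), can be handy for inspecting the 'sigmoid' depth distribution without numpy; this is an illustrative sketch, not repo code:

```python
import math

def logistic(x: float, L: float = 100, x0: float = 50, k: float = 0.1) -> float:
    if x in (0, 100):  # endpoints pass through unchanged
        return x
    # equivalent to L * sigmoid(-k * (x - x0)) in the class above
    return round(L / (1 + math.exp(k * (x - x0))), 3)

print(logistic(50))  # 50.0
```

The midpoint maps to itself (50 → 50.0) while the endpoints are passed through by the guard, so the sigmoid schedule concentrates sample points away from the middle of the document.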
    async def bound_evaluate_and_log(self, sem, *args):
        async with sem:
            await self.evaluate_and_log(*args)

    async def run_test(self):
        sem = Semaphore(self.num_concurrent_requests)

        # Run through each iteration of context_lengths and depths
        tasks = []
        for context_length in self.context_lengths:
            for depth_percent in self.document_depth_percents:
                task = self.bound_evaluate_and_log(sem, context_length, depth_percent)
                tasks.append(task)

        # Wait for all tasks to complete
        await asyncio.gather(*tasks)
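run_test fans out one task per (context_length, depth_percent) pair and bounds parallelism with a Semaphore. A self-contained sketch of that pattern (the sleep stands in for the real model call):

```python
import asyncio

async def worker(sem: asyncio.Semaphore, length: int, depth: int, done: list):
    async with sem:             # at most N workers run inside this block
        await asyncio.sleep(0)  # stand-in for the real model call
        done.append((length, depth))

async def run_all():
    sem = asyncio.Semaphore(2)  # num_concurrent_requests = 2
    done: list = []
    tasks = [worker(sem, c, d, done)
             for c in (1000, 2000) for d in (0, 50, 100)]
    await asyncio.gather(*tasks)
    return done

results = asyncio.run(run_all())
print(len(results))  # 6
```

All tasks are created up front; the semaphore only limits how many are in flight at once, which is what keeps the tester under API rate limits.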
    async def evaluate_and_log(self, context_length, depth_percent):
        # Checks to see if you've already checked a length/percent/version.
        # This helps if the program stops running and you want to restart later
        if self.save_results:
            if self.result_exists(context_length, depth_percent):
                return

        # Go generate the required length context and place your needle statement in
        context = await self.generate_context(context_length, depth_percent)

        # Prepare your message to send to the model you're going to evaluate
        prompt = self.model_to_test.generate_prompt(context, self.retrieval_question)

        test_start_time = time.time()

        # Go see if the model can answer the question to pull out your random fact
        response = await self.model_to_test.evaluate_model(prompt)

        test_end_time = time.time()
        test_elapsed_time = test_end_time - test_start_time

        # Compare the response to the actual needle you placed
        score = self.evaluation_model.evaluate_response(response)

        results = {
            # 'context' : context, # Uncomment this line if you'd like to save the context the model was asked to retrieve from. Warning: This will become very large.
            'model' : self.model_name,
            'context_length' : int(context_length),
            'depth_percent' : float(depth_percent),
            'version' : self.results_version,
            'needle' : self.needle,
            'model_response' : response,
            'score' : score,
            'test_duration_seconds' : test_elapsed_time,
            'test_timestamp_utc' : datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S%z')
        }

        self.testing_results.append(results)

        if self.print_ongoing_status:
            print (f"-- Test Summary -- ")
            print (f"Duration: {test_elapsed_time:.1f} seconds")
            print (f"Context: {context_length} tokens")
            print (f"Depth: {depth_percent}%")
            print (f"Score: {score}")
            print (f"Response: {response}\n")

        context_file_location = f'{self.model_name.replace(".", "_")}_len_{context_length}_depth_{int(depth_percent*100)}'

        if self.save_contexts:
            results['file_name'] = context_file_location

            # Save the context to file for retesting
            if not os.path.exists('contexts'):
                os.makedirs('contexts')

            with open(f'contexts/{context_file_location}_context.txt', 'w') as f:
                f.write(context)

        if self.save_results:
            # Save the result to file for retesting
            if not os.path.exists('results'):
                os.makedirs('results')

            with open(f'results/{context_file_location}_results.json', 'w') as f:
                json.dump(results, f)

        if self.seconds_to_sleep_between_completions:
            await asyncio.sleep(self.seconds_to_sleep_between_completions)
    def result_exists(self, context_length, depth_percent):
        """
        Checks to see if a result has already been evaluated or not
        """
        results_dir = 'results/'
        if not os.path.exists(results_dir):
            return False

        for filename in os.listdir(results_dir):
            if filename.endswith('.json'):
                with open(os.path.join(results_dir, filename), 'r') as f:
                    result = json.load(f)
                    context_length_met = result['context_length'] == context_length
                    depth_percent_met = result['depth_percent'] == depth_percent
                    version_met = result.get('version', 1) == self.results_version
                    model_met = result['model'] == self.model_name
                    if context_length_met and depth_percent_met and version_met and model_met:
                        return True
        return False
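result_exists scans the saved JSON files so an interrupted run can resume without redoing finished cells. A self-contained sketch of the same idea against a temporary results directory (function and file names here are illustrative, not the repo's):

```python
import json
import os
import tempfile

def result_exists(results_dir, model, context_length, depth_percent, version=1):
    """True if a saved result matches the model / length / depth / version."""
    if not os.path.exists(results_dir):
        return False
    for name in os.listdir(results_dir):
        if not name.endswith('.json'):
            continue
        with open(os.path.join(results_dir, name)) as f:
            r = json.load(f)
        if (r['model'] == model and r['context_length'] == context_length
                and r['depth_percent'] == depth_percent
                and r.get('version', 1) == version):
            return True
    return False

d = tempfile.mkdtemp()
with open(os.path.join(d, 'a_results.json'), 'w') as f:
    json.dump({'model': 'm', 'context_length': 1000, 'depth_percent': 50.0}, f)
print(result_exists(d, 'm', 1000, 50.0))   # True
print(result_exists(d, 'm', 2000, 50.0))   # False
```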
    async def generate_context(self, context_length, depth_percent):
        # Get your haystack dir files loaded into a string
        context = self.read_context_files()

        # Truncate the haystack dir essays to the context length you desire
        context = self.encode_and_trim(context, context_length)

        # Insert your random statement according to your depth percent
        context = self.insert_needle(context, depth_percent, context_length)

        return context
    def insert_needle(self, context, depth_percent, context_length):
        tokens_needle = self.model_to_test.encode_text_to_tokens(self.needle)
        tokens_context = self.model_to_test.encode_text_to_tokens(context)

        # Reduce the context length by the final_context_length_buffer. This accounts for the system message, the user question, and the response.
        context_length -= self.final_context_length_buffer

        # If your context + needle are longer than the context length (which it will be), then reduce tokens from the context by the needle length
        if len(tokens_context) + len(tokens_needle) > context_length:
            tokens_context = tokens_context[:context_length - len(tokens_needle)]

        if depth_percent == 100:
            # If your depth percent is 100 (which means your needle is the last thing in the doc), throw it at the end
            tokens_new_context = tokens_context + tokens_needle
        else:
            # Go get the position (in terms of tokens) to insert your needle
            insertion_point = int(len(tokens_context) * (depth_percent / 100))

            # tokens_new_context represents the tokens before the needle
            tokens_new_context = tokens_context[:insertion_point]

            # We want to make sure that we place our needle at a sentence break so we first see what token a '.' is
            period_tokens = self.model_to_test.encode_text_to_tokens('.')

            # Then we iterate backwards until we find the first period
            while tokens_new_context and tokens_new_context[-1] not in period_tokens:
                insertion_point -= 1
                tokens_new_context = tokens_context[:insertion_point]

            # Once we get there, add in your needle, and stick the rest of your context in on the other end.
            # Now we have a needle in a haystack
            tokens_new_context += tokens_needle + tokens_context[insertion_point:]

        # Convert back to a string and return it
        new_context = self.model_to_test.decode_tokens(tokens_new_context)
        return new_context
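insert_needle works on token IDs, but the sentence-boundary backtracking is easier to see on a toy word-level "tokenizer". The helper below (insert_at_depth, hypothetical, not part of the repo) applies the same logic: jump to the depth offset, back up to the nearest sentence end, then splice the needle in:

```python
def insert_at_depth(context_words: list[str], needle_words: list[str],
                    depth_percent: float) -> list[str]:
    if depth_percent == 100:
        return context_words + needle_words
    point = int(len(context_words) * depth_percent / 100)
    # back up until the preceding token ends a sentence
    while point > 0 and not context_words[point - 1].endswith('.'):
        point -= 1
    return context_words[:point] + needle_words + context_words[point:]

ctx = "One two three. Four five six. Seven eight.".split()
out = insert_at_depth(ctx, ["NEEDLE."], 60)
print(" ".join(out))
# → One two three. NEEDLE. Four five six. Seven eight.
```

A 60% depth on 8 words lands mid-sentence (index 4), so the loop backs up to the period after "three." before inserting, just as the token-level version backs up to a '.' token.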
    def get_context_length_in_tokens(self, context):
        return len(self.model_to_test.encode_text_to_tokens(context))

    def read_context_files(self):
        context = ""
        max_context_length = max(self.context_lengths)
        base_dir = os.path.abspath(os.path.dirname(__file__))  # Package directory

        while self.get_context_length_in_tokens(context) < max_context_length:
            for file in glob.glob(os.path.join(base_dir, self.haystack_dir, "*.txt")):
                with open(file, 'r') as f:
                    context += f.read()
        return context

    def encode_and_trim(self, context, context_length):
        tokens = self.model_to_test.encode_text_to_tokens(context)
        if len(tokens) > context_length:
            context = self.model_to_test.decode_tokens(tokens, context_length)
        return context
    def get_results(self):
        return self.testing_results

    def print_start_test_summary(self):
        print ("\n")
        print ("Starting Needle In A Haystack Testing...")
        print (f"- Model: {self.model_name}")
        print (f"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}")
        print (f"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: {max(self.document_depth_percents)}%")
        print (f"- Needle: {self.needle.strip()}")
        print ("\n\n")

    def start_test(self):
        if self.print_ongoing_status:
            self.print_start_test_summary()
        asyncio.run(self.run_test())
Multi-needle test
LLMMultiNeedleHaystackTester inherits from LLMNeedleHaystackTester; the interesting parts are the insert_needles and evaluate_and_log implementations.
import asyncio
import glob
import json
import os
import time
from asyncio import Semaphore
from datetime import datetime, timezone

import numpy as np

from .evaluators import Evaluator
from .llm_needle_haystack_tester import LLMNeedleHaystackTester
from .providers import ModelProvider

class LLMMultiNeedleHaystackTester(LLMNeedleHaystackTester):
    """
    Extends LLMNeedleHaystackTester to support testing with multiple needles in the haystack.

    Attributes:
        needles (list): A list of needles (facts) to insert into the haystack (context).
        model_to_test (ModelProvider): The model being tested.
        evaluator (Evaluator): The evaluator used to assess the model's performance.
        print_ongoing_status (bool): Flag to print ongoing status messages.
        eval_set (str): The evaluation set identifier.
    """
    def __init__(self, *args,
                 needles=[],
                 model_to_test: ModelProvider = None,
                 evaluator: Evaluator = None,
                 print_ongoing_status = True,
                 eval_set = "multi-needle-eval-sf",
                 **kwargs):
        super().__init__(*args, model_to_test=model_to_test, **kwargs)
        self.needles = needles
        self.evaluator = evaluator
        self.model_to_test = model_to_test
        self.eval_set = eval_set
        self.model_name = self.model_to_test.model_name
        self.print_ongoing_status = print_ongoing_status
        self.insertion_percentages = []
    async def insert_needles(self, context, depth_percent, context_length):
        """
        Inserts multiple needles (specific facts or pieces of information) into the original context string at
        designated depth percentages, effectively distributing these needles throughout the context. This method
        is designed to test a model's ability to retrieve specific information (needles) from a larger body of text
        (haystack) based on the placement depth of these needles.

        The method first encodes the context and each needle into tokens to calculate their lengths in tokens.
        It then adjusts the context length to account for the final buffer length. This is crucial for ensuring
        that the total token count (context plus needles) does not exceed the maximum allowable context length,
        which might otherwise lead to information being truncated.

        This approach calculates the initial insertion point for the first needle as before, then calculates even
        spacing for the remaining needles based on the remaining context length. It ensures that needles are
        distributed as evenly as possible throughout the context after the first insertion.

        Args:
            context (str): The original context string.
            depth_percent (float): The depth percent at which to insert the needles.
            context_length (int): The total length of the context in tokens, adjusted for the final buffer.

        Returns:
            str: The new context with needles inserted.
        """
        tokens_context = self.model_to_test.encode_text_to_tokens(context)
        context_length -= self.final_context_length_buffer

        # Calculate the total length of all needles in tokens
        total_needles_length = sum(len(self.model_to_test.encode_text_to_tokens(needle)) for needle in self.needles)

        # Ensure the context length accounts for the needles
        if len(tokens_context) + total_needles_length > context_length:
            tokens_context = tokens_context[:context_length - total_needles_length]

        # To distribute the needles evenly, calculate the depth interval between insertions
        depth_percent_interval = (100 - depth_percent) / len(self.needles)

        # Reset the insertion percentages list for the current context
        self.insertion_percentages = []

        # Insert needles at the calculated points
        for needle in self.needles:
            tokens_needle = self.model_to_test.encode_text_to_tokens(needle)
            if depth_percent == 100:
                # If the depth percent is 100 (the needle is the last thing in the doc), append it at the end
                tokens_context = tokens_context + tokens_needle
            else:
                # Find the position (in terms of tokens) at which to insert the needle
                insertion_point = int(len(tokens_context) * (depth_percent / 100))
                # tokens_new_context represents the tokens before the needle
                tokens_new_context = tokens_context[:insertion_point]
                # We want to place the needle at a sentence break, so first see which token(s) a '.' encodes to
                period_tokens = self.model_to_test.encode_text_to_tokens('.')
                # Then iterate backwards until we find the first period
                while tokens_new_context and tokens_new_context[-1] not in period_tokens:
                    insertion_point -= 1
                    tokens_new_context = tokens_context[:insertion_point]
                # Insert the needle into the context at the found position
                tokens_context = tokens_context[:insertion_point] + tokens_needle + tokens_context[insertion_point:]
                # Log the insertion
                insertion_percentage = (insertion_point / len(tokens_context)) * 100
                self.insertion_percentages.append(insertion_percentage)
                print(f"Inserted '{needle}' at {insertion_percentage:.2f}% of the context, total length now: {len(tokens_context)} tokens")
            # Adjust depth for the next needle
            depth_percent += depth_percent_interval

        new_context = self.model_to_test.decode_tokens(tokens_context)
        return new_context
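The even-spacing arithmetic above can be isolated into a small standalone sketch. This is illustrative only: it uses a toy whitespace "tokenizer" (an assumption, not the real model tokenizer) and skips the sentence-break adjustment, to show just how `depth_percent_interval` spreads the needles.

```python
def insertion_depths(depth_percent, num_needles):
    """Depth percentages at which each needle lands, per the interval formula above."""
    interval = (100 - depth_percent) / num_needles
    return [depth_percent + i * interval for i in range(num_needles)]

def insert_needles_sketch(words, needles, depth_percent):
    """Insert each needle (a short string) into a word list at its computed depth."""
    for needle, depth in zip(needles, insertion_depths(depth_percent, len(needles))):
        point = int(len(words) * depth / 100)
        words = words[:point] + needle.split() + words[point:]
    return words
```

For example, three needles starting at depth 10% land at depths 10%, 40%, and 70%, so the region between the first insertion and the end of the document is divided evenly.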
    def encode_and_trim(self, context, context_length):
        """
        Encodes the context to tokens and trims it to the specified length.

        Args:
            context (str): The context to encode and trim.
            context_length (int): The desired length of the context in tokens.

        Returns:
            str: The encoded and trimmed context.
        """
        tokens = self.model_to_test.encode_text_to_tokens(context)
        if len(tokens) > context_length:
            context = self.model_to_test.decode_tokens(tokens, context_length)
        return context
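The same encode-then-trim pattern can be shown with a trivial word-level tokenizer standing in for the model tokenizer (the `WordTokenizer` class here is a hypothetical stand-in, not part of the library):

```python
class WordTokenizer:
    """Toy tokenizer: one token per whitespace-separated word."""
    def encode_text_to_tokens(self, text):
        return text.split()

    def decode_tokens(self, tokens, context_length=None):
        if context_length is not None:
            tokens = tokens[:context_length]
        return " ".join(tokens)

def encode_and_trim(tokenizer, context, context_length):
    # Mirrors the method above: only re-decode when trimming is actually needed.
    tokens = tokenizer.encode_text_to_tokens(context)
    if len(tokens) > context_length:
        context = tokenizer.decode_tokens(tokens, context_length)
    return context
```

Note that the real implementation measures length in model tokens, so the trimmed character count depends on the tokenizer in use.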
    async def generate_context(self, context_length, depth_percent):
        """
        Generates a context of a specified length and inserts needles at given depth percentages.

        Args:
            context_length (int): The total length of the context in tokens.
            depth_percent (float): The depth percent for needle insertion.

        Returns:
            str: The context with needles inserted.
        """
        context = self.read_context_files()
        context = self.encode_and_trim(context, context_length)
        context = await self.insert_needles(context, depth_percent, context_length)
        return context
    async def evaluate_and_log(self, context_length, depth_percent):
        """
        Evaluates the model's performance with the generated context and logs the results.

        Args:
            context_length (int): The length of the context in tokens.
            depth_percent (float): The depth percent for needle insertion.
        """
        if self.save_results:
            if self.result_exists(context_length, depth_percent):
                return

        # Generate the required length context and place the needle statements in it
        context = await self.generate_context(context_length, depth_percent)
        test_start_time = time.time()

        # LangSmith
        ## TODO: Support for other evaluators
        if self.evaluator.__class__.__name__ == "LangSmithEvaluator":
            print("EVALUATOR: LANGSMITH")
            chain = self.model_to_test.get_langchain_runnable(context)
            self.evaluator.evaluate_chain(chain, context_length, depth_percent, self.model_to_test.model_name, self.eval_set, len(self.needles), self.needles, self.insertion_percentages)
            test_end_time = time.time()
            test_elapsed_time = test_end_time - test_start_time
        else:
            print("EVALUATOR: OpenAI Model")
            # Prepare the message to send to the model being evaluated
            prompt = self.model_to_test.generate_prompt(context, self.retrieval_question)
            # See if the model can answer the question and pull out the needle
            response = await self.model_to_test.evaluate_model(prompt)
            # Compare the response to the actual needle that was placed
            score = self.evaluation_model.evaluate_response(response)
            test_end_time = time.time()
            test_elapsed_time = test_end_time - test_start_time

            results = {
                # 'context' : context, # Uncomment this line if you'd like to save the context the model was asked to retrieve from. Warning: This will become very large.
                'model' : self.model_to_test.model_name,
                'context_length' : int(context_length),
                'depth_percent' : float(depth_percent),
                'version' : self.results_version,
                'needle' : self.needle,
                'model_response' : response,
                'score' : score,
                'test_duration_seconds' : test_elapsed_time,
                'test_timestamp_utc' : datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S%z')
            }
            self.testing_results.append(results)

            if self.print_ongoing_status:
                print(f"-- Test Summary -- ")
                print(f"Duration: {test_elapsed_time:.1f} seconds")
                print(f"Context: {context_length} tokens")
                print(f"Depth: {depth_percent}%")
                print(f"Score: {score}")
                print(f"Response: {response}\n")

            context_file_location = f'{self.model_name.replace(".", "_")}_len_{context_length}_depth_{int(depth_percent*100)}'

            if self.save_contexts:
                results['file_name'] = context_file_location
                # Save the context to file for retesting
                if not os.path.exists('contexts'):
                    os.makedirs('contexts')
                with open(f'contexts/{context_file_location}_context.txt', 'w') as f:
                    f.write(context)

            if self.save_results:
                if not os.path.exists('results'):
                    os.makedirs('results')
                # Save the result to file for retesting
                with open(f'results/{context_file_location}_results.json', 'w') as f:
                    json.dump(results, f)

        if self.seconds_to_sleep_between_completions:
            await asyncio.sleep(self.seconds_to_sleep_between_completions)
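The file-naming convention used when saving contexts and results can be pulled out as a standalone helper. This is an illustrative sketch mirroring the f-string above (the function name is hypothetical): dots in the model name become underscores so the name is filesystem-safe, and the depth is folded into the name.

```python
def context_file_location(model_name, context_length, depth_percent):
    """Build the base file name used for saved contexts/results."""
    safe_model = model_name.replace(".", "_")  # e.g. 'gpt-4.1' -> 'gpt-4_1'
    return f"{safe_model}_len_{context_length}_depth_{int(depth_percent * 100)}"
```

The saved context then goes to `contexts/<name>_context.txt` and the evaluation record to `results/<name>_results.json`.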
    async def bound_evaluate_and_log(self, sem, *args):
        async with sem:
            await self.evaluate_and_log(*args)

    async def run_test(self):
        sem = Semaphore(self.num_concurrent_requests)

        # Run through each combination of context_lengths and depths
        tasks = []
        for context_length in self.context_lengths:
            for depth_percent in self.document_depth_percents:
                task = self.bound_evaluate_and_log(sem, context_length, depth_percent)
                tasks.append(task)

        # Wait for all tasks to complete
        await asyncio.gather(*tasks)
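The concurrency pattern in `run_test` can be demonstrated with a self-contained sketch: a `Semaphore` caps how many evaluations run at once (matching `num_concurrent_requests`), while `asyncio.gather` drives the full grid of (length, depth) combinations. The `fake_evaluate_and_log` coroutine here is a stand-in for the real model call, included so the sketch runs on its own.

```python
import asyncio

async def fake_evaluate_and_log(results, context_length, depth_percent):
    # Stand-in for the real model call and evaluation.
    await asyncio.sleep(0)
    results.append((context_length, depth_percent))

async def bound_run(sem, results, context_length, depth_percent):
    # The semaphore limits how many evaluations are in flight at once.
    async with sem:
        await fake_evaluate_and_log(results, context_length, depth_percent)

async def run_grid(context_lengths, depth_percents, max_concurrency=2):
    sem = asyncio.Semaphore(max_concurrency)
    results = []
    tasks = [bound_run(sem, results, cl, dp)
             for cl in context_lengths
             for dp in depth_percents]
    await asyncio.gather(*tasks)
    return results
```

Raising `max_concurrency` speeds up the sweep but, as the parameter docs warn, API rate limits bound how far this helps in practice.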
    def print_start_test_summary(self):
        print("\n")
        print("Starting Needle In A Haystack Testing...")
        print(f"- Model: {self.model_name}")
        print(f"- Context Lengths: {len(self.context_lengths)}, Min: {min(self.context_lengths)}, Max: {max(self.context_lengths)}")
        print(f"- Document Depths: {len(self.document_depth_percents)}, Min: {min(self.document_depth_percents)}%, Max: {max(self.document_depth_percents)}%")
        print(f"- Needles: {self.needles}")
        print("\n\n")

    def start_test(self):
        if self.print_ongoing_status:
            self.print_start_test_summary()
        asyncio.run(self.run_test())
