Skip to content

Commit f511c96

Browse files
Python: inverted the filter logic for InMemory vector stores (#13457)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> Fixes weaknesses in the forbidden-name (denylist) approach used to validate filter expressions for the in-memory vector store. ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> Replaces the denylist with an allowlist: only AST node types listed in `allowed_filter_ast_nodes` and call targets named in `allowed_filter_functions` are accepted, and the only names that may be referenced are the lambda's own parameters. Both allowlists are class attributes and can be overridden in subclasses. ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [ ] I didn't break anyone 😄
1 parent c97c3e6 commit f511c96

File tree

2 files changed

+130
-46
lines changed

2 files changed

+130
-46
lines changed

python/semantic_kernel/connectors/in_memory.py

Lines changed: 130 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,80 @@ class InMemoryCollection(
9292
supported_key_types: ClassVar[set[str] | None] = {"str", "int", "float"}
9393
supported_search_types: ClassVar[set[SearchType]] = {SearchType.VECTOR}
9494

95+
# Allowlist of AST node types permitted in filter expressions.
96+
# This can be overridden in subclasses to extend or restrict allowed operations.
97+
allowed_filter_ast_nodes: ClassVar[set[type]] = {
98+
ast.Expression,
99+
ast.Lambda,
100+
ast.arguments,
101+
ast.arg,
102+
# Comparisons and boolean operations
103+
ast.Compare,
104+
ast.BoolOp,
105+
ast.UnaryOp,
106+
ast.And,
107+
ast.Or,
108+
ast.Not,
109+
ast.Eq,
110+
ast.NotEq,
111+
ast.Lt,
112+
ast.LtE,
113+
ast.Gt,
114+
ast.GtE,
115+
ast.In,
116+
ast.NotIn,
117+
ast.Is,
118+
ast.IsNot,
119+
# Data access
120+
ast.Name,
121+
ast.Load,
122+
ast.Attribute,
123+
ast.Subscript,
124+
ast.Index, # For Python 3.8 compatibility
125+
ast.Slice,
126+
# Literals
127+
ast.Constant,
128+
ast.List,
129+
ast.Tuple,
130+
ast.Set,
131+
ast.Dict,
132+
# Basic arithmetic (useful for computed comparisons)
133+
ast.BinOp,
134+
ast.Add,
135+
ast.Sub,
136+
ast.Mult,
137+
ast.Div,
138+
ast.Mod,
139+
ast.FloorDiv,
140+
# Function calls (the called function/method name is validated separately against `allowed_filter_functions`)
141+
ast.Call,
142+
}
143+
144+
# Allowlist of function/method names that can be called in filter expressions.
145+
allowed_filter_functions: ClassVar[set[str]] = {
146+
"len",
147+
"str",
148+
"int",
149+
"float",
150+
"bool",
151+
"abs",
152+
"min",
153+
"max",
154+
"sum",
155+
"any",
156+
"all",
157+
"lower",
158+
"upper",
159+
"strip",
160+
"startswith",
161+
"endswith",
162+
"contains",
163+
"get",
164+
"keys",
165+
"values",
166+
"items",
167+
}
168+
95169
def __init__(
96170
self,
97171
record_type: type[TModel],
@@ -100,7 +174,17 @@ def __init__(
100174
embedding_generator: EmbeddingGeneratorBase | None = None,
101175
**kwargs: Any,
102176
):
103-
"""Create a In Memory Collection."""
177+
"""Create an In Memory Collection.
178+
179+
In Memory collections are ephemeral and exist only in memory.
180+
They do not persist data to disk or any external storage.
181+
182+
> [Important]
183+
> Filters are powerful, so never build them from untrusted input.
184+
> Filter strings for this collection are parsed with Python's `ast` module, then compiled and evaluated, so code is executed.
185+
> Only allowlisted AST node types and function names are permitted in filter expressions; this reduces, but does not eliminate, the security risk.
186+
187+
"""
104188
super().__init__(
105189
record_type=record_type,
106190
definition=definition,
@@ -243,39 +327,67 @@ def _get_filtered_records(self, options: VectorSearchOptions) -> dict[TKey, Attr
243327
return filtered_records
244328

245329
def _parse_and_validate_filter(self, filter_str: str) -> Callable:
246-
"""Parse and validate a string filter as a lambda expression, then return the callable."""
247-
forbidden_names = {
248-
"__import__",
249-
"open",
250-
"eval",
251-
"exec",
252-
"__builtins__",
253-
"__class__",
254-
"__bases__",
255-
"__subclasses__",
256-
}
330+
"""Parse and validate a string filter as a lambda expression, then return the callable.
331+
332+
Uses an allowlist approach - only explicitly permitted AST node types and function names
333+
are allowed. This can be customized by overriding `allowed_filter_ast_nodes` and
334+
`allowed_filter_functions` class attributes.
335+
"""
257336
try:
258337
tree = ast.parse(filter_str, mode="eval")
259338
except SyntaxError as e:
260339
raise VectorStoreOperationException(f"Filter string is not valid Python: {e}") from e
261-
# Only allow lambda expressions
340+
341+
# Only allow lambda expressions at the top level
262342
if not (isinstance(tree, ast.Expression) and isinstance(tree.body, ast.Lambda)):
263343
raise VectorStoreOperationException(
264344
"Filter string must be a lambda expression, e.g. 'lambda x: x.key == 1'"
265345
)
266-
# Walk the AST to look for forbidden names and attribute access
346+
347+
# Get the lambda parameter name(s) to allow them as valid Name nodes
348+
lambda_node = tree.body
349+
lambda_param_names = {arg.arg for arg in lambda_node.args.args}
350+
351+
# Walk the AST to validate all nodes against the allowlist
267352
for node in ast.walk(tree):
268-
if isinstance(node, ast.Name) and node.id in forbidden_names:
269-
raise VectorStoreOperationException(f"Use of '{node.id}' is not allowed in filter expressions.")
270-
if isinstance(node, ast.Attribute) and node.attr in forbidden_names:
271-
raise VectorStoreOperationException(f"Use of '{node.attr}' is not allowed in filter expressions.")
353+
node_type = type(node)
354+
355+
# Check if the node type is allowed
356+
if node_type not in self.allowed_filter_ast_nodes:
357+
raise VectorStoreOperationException(
358+
f"AST node type '{node_type.__name__}' is not allowed in filter expressions."
359+
)
360+
361+
# For Name nodes, only allow the lambda parameter
362+
if isinstance(node, ast.Name) and node.id not in lambda_param_names:
363+
raise VectorStoreOperationException(
364+
f"Use of name '{node.id}' is not allowed in filter expressions. "
365+
f"Only the lambda parameter(s) ({', '.join(lambda_param_names)}) can be used."
366+
)
367+
368+
# For Call nodes, validate that only allowed functions are called
369+
if isinstance(node, ast.Call):
370+
func_name = None
371+
if isinstance(node.func, ast.Name):
372+
func_name = node.func.id
373+
elif isinstance(node.func, ast.Attribute):
374+
func_name = node.func.attr
375+
376+
if func_name and func_name not in self.allowed_filter_functions:
377+
raise VectorStoreOperationException(
378+
f"Function '{func_name}' is not allowed in filter expressions. "
379+
f"Allowed functions: {', '.join(sorted(self.allowed_filter_functions))}"
380+
)
381+
272382
try:
273383
code = compile(tree, filename="<filter>", mode="eval")
274384
func = eval(code, {"__builtins__": {}}, {}) # nosec
275385
except Exception as e:
276386
raise VectorStoreOperationException(f"Error compiling filter: {e}") from e
387+
277388
if not callable(func):
278389
raise VectorStoreOperationException("Compiled filter is not callable.")
390+
279391
return func
280392

281393
def _run_filter(self, filter: Callable, record: AttributeDict[TAKey, TAValue]) -> bool:
Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,7 @@
11
# Copyright (c) Microsoft. All rights reserved.
22

3-
from unittest.mock import AsyncMock, Mock
43

5-
import pytest
6-
7-
from semantic_kernel.connectors.ai.chat_completion_client_base import ChatCompletionClientBase
8-
from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings
9-
from semantic_kernel.contents.chat_message_content import ChatMessageContent
104
from semantic_kernel.core_plugins.conversation_summary_plugin import ConversationSummaryPlugin
11-
from semantic_kernel.functions.kernel_arguments import KernelArguments
12-
from semantic_kernel.kernel import Kernel
135
from semantic_kernel.prompt_template.prompt_template_config import PromptTemplateConfig
146

157

@@ -25,23 +17,3 @@ def test_conversation_summary_plugin_with_deprecated_value(kernel):
2517
plugin = ConversationSummaryPlugin(config, kernel=kernel)
2618
assert plugin._summarizeConversationFunction is not None
2719
assert plugin.return_key == "summary"
28-
29-
30-
@pytest.mark.asyncio
31-
async def test_summarize_conversation(kernel: Kernel):
32-
service = AsyncMock(spec=ChatCompletionClientBase)
33-
service.service_id = "default"
34-
service.get_chat_message_contents = AsyncMock(
35-
return_value=[ChatMessageContent(role="assistant", content="Hello World!")]
36-
)
37-
service.get_prompt_execution_settings_class = Mock(return_value=PromptExecutionSettings)
38-
kernel.add_service(service)
39-
config = PromptTemplateConfig(
40-
name="test", description="test", execution_settings={"default": PromptExecutionSettings()}
41-
)
42-
kernel.add_plugin(ConversationSummaryPlugin(config), "summarizer")
43-
args = KernelArguments(input="Hello World!")
44-
45-
await kernel.invoke(plugin_name="summarizer", function_name="SummarizeConversation", arguments=args)
46-
args["summary"] == "Hello world"
47-
service.get_chat_message_contents.assert_called_once()

0 commit comments

Comments
 (0)