@@ -92,6 +92,80 @@ class InMemoryCollection(
9292 supported_key_types : ClassVar [set [str ] | None ] = {"str" , "int" , "float" }
9393 supported_search_types : ClassVar [set [SearchType ]] = {SearchType .VECTOR }
9494
95+ # Allowlist of AST node types permitted in filter expressions.
96+ # This can be overridden in subclasses to extend or restrict allowed operations.
97+ allowed_filter_ast_nodes : ClassVar [set [type ]] = {
98+ ast .Expression ,
99+ ast .Lambda ,
100+ ast .arguments ,
101+ ast .arg ,
102+ # Comparisons and boolean operations
103+ ast .Compare ,
104+ ast .BoolOp ,
105+ ast .UnaryOp ,
106+ ast .And ,
107+ ast .Or ,
108+ ast .Not ,
109+ ast .Eq ,
110+ ast .NotEq ,
111+ ast .Lt ,
112+ ast .LtE ,
113+ ast .Gt ,
114+ ast .GtE ,
115+ ast .In ,
116+ ast .NotIn ,
117+ ast .Is ,
118+ ast .IsNot ,
119+ # Data access
120+ ast .Name ,
121+ ast .Load ,
122+ ast .Attribute ,
123+ ast .Subscript ,
124+ ast .Index , # For Python 3.8 compatibility
125+ ast .Slice ,
126+ # Literals
127+ ast .Constant ,
128+ ast .List ,
129+ ast .Tuple ,
130+ ast .Set ,
131+ ast .Dict ,
132+ # Basic arithmetic (useful for computed comparisons)
133+ ast .BinOp ,
134+ ast .Add ,
135+ ast .Sub ,
136+ ast .Mult ,
137+ ast .Div ,
138+ ast .Mod ,
139+ ast .FloorDiv ,
140+ # Function calls (restricted to safe builtins separately)
141+ ast .Call ,
142+ }
143+
144+ # Allowlist of function/method names that can be called in filter expressions.
145+ allowed_filter_functions : ClassVar [set [str ]] = {
146+ "len" ,
147+ "str" ,
148+ "int" ,
149+ "float" ,
150+ "bool" ,
151+ "abs" ,
152+ "min" ,
153+ "max" ,
154+ "sum" ,
155+ "any" ,
156+ "all" ,
157+ "lower" ,
158+ "upper" ,
159+ "strip" ,
160+ "startswith" ,
161+ "endswith" ,
162+ "contains" ,
163+ "get" ,
164+ "keys" ,
165+ "values" ,
166+ "items" ,
167+ }
168+
95169 def __init__ (
96170 self ,
97171 record_type : type [TModel ],
@@ -100,7 +174,17 @@ def __init__(
100174 embedding_generator : EmbeddingGeneratorBase | None = None ,
101175 ** kwargs : Any ,
102176 ):
103- """Create a In Memory Collection."""
177+ """Create a In Memory Collection.
178+
179+ In Memory collections are ephemeral and exist only in memory.
180+ They do not persist data to disk or any external storage.
181+
182+ > [Important]
183+ > Filters are powerful things, so make sure to not allow untrusted input here.
184+ > Filters for this collection are parsed and evaluated using Python's `ast` module, so code might be executed.
185+ > We only allow certain AST nodes and functions to be used in the filter expressions to mitigate security risks.
186+
187+ """
104188 super ().__init__ (
105189 record_type = record_type ,
106190 definition = definition ,
@@ -243,39 +327,67 @@ def _get_filtered_records(self, options: VectorSearchOptions) -> dict[TKey, Attr
243327 return filtered_records
244328
245329 def _parse_and_validate_filter (self , filter_str : str ) -> Callable :
246- """Parse and validate a string filter as a lambda expression, then return the callable."""
247- forbidden_names = {
248- "__import__" ,
249- "open" ,
250- "eval" ,
251- "exec" ,
252- "__builtins__" ,
253- "__class__" ,
254- "__bases__" ,
255- "__subclasses__" ,
256- }
330+ """Parse and validate a string filter as a lambda expression, then return the callable.
331+
332+ Uses an allowlist approach - only explicitly permitted AST node types and function names
333+ are allowed. This can be customized by overriding `allowed_filter_ast_nodes` and
334+ `allowed_filter_functions` class attributes.
335+ """
257336 try :
258337 tree = ast .parse (filter_str , mode = "eval" )
259338 except SyntaxError as e :
260339 raise VectorStoreOperationException (f"Filter string is not valid Python: { e } " ) from e
261- # Only allow lambda expressions
340+
341+ # Only allow lambda expressions at the top level
262342 if not (isinstance (tree , ast .Expression ) and isinstance (tree .body , ast .Lambda )):
263343 raise VectorStoreOperationException (
264344 "Filter string must be a lambda expression, e.g. 'lambda x: x.key == 1'"
265345 )
266- # Walk the AST to look for forbidden names and attribute access
346+
347+ # Get the lambda parameter name(s) to allow them as valid Name nodes
348+ lambda_node = tree .body
349+ lambda_param_names = {arg .arg for arg in lambda_node .args .args }
350+
351+ # Walk the AST to validate all nodes against the allowlist
267352 for node in ast .walk (tree ):
268- if isinstance (node , ast .Name ) and node .id in forbidden_names :
269- raise VectorStoreOperationException (f"Use of '{ node .id } ' is not allowed in filter expressions." )
270- if isinstance (node , ast .Attribute ) and node .attr in forbidden_names :
271- raise VectorStoreOperationException (f"Use of '{ node .attr } ' is not allowed in filter expressions." )
353+ node_type = type (node )
354+
355+ # Check if the node type is allowed
356+ if node_type not in self .allowed_filter_ast_nodes :
357+ raise VectorStoreOperationException (
358+ f"AST node type '{ node_type .__name__ } ' is not allowed in filter expressions."
359+ )
360+
361+ # For Name nodes, only allow the lambda parameter
362+ if isinstance (node , ast .Name ) and node .id not in lambda_param_names :
363+ raise VectorStoreOperationException (
364+ f"Use of name '{ node .id } ' is not allowed in filter expressions. "
365+ f"Only the lambda parameter(s) ({ ', ' .join (lambda_param_names )} ) can be used."
366+ )
367+
368+ # For Call nodes, validate that only allowed functions are called
369+ if isinstance (node , ast .Call ):
370+ func_name = None
371+ if isinstance (node .func , ast .Name ):
372+ func_name = node .func .id
373+ elif isinstance (node .func , ast .Attribute ):
374+ func_name = node .func .attr
375+
376+ if func_name and func_name not in self .allowed_filter_functions :
377+ raise VectorStoreOperationException (
378+ f"Function '{ func_name } ' is not allowed in filter expressions. "
379+ f"Allowed functions: { ', ' .join (sorted (self .allowed_filter_functions ))} "
380+ )
381+
272382 try :
273383 code = compile (tree , filename = "<filter>" , mode = "eval" )
274384 func = eval (code , {"__builtins__" : {}}, {}) # nosec
275385 except Exception as e :
276386 raise VectorStoreOperationException (f"Error compiling filter: { e } " ) from e
387+
277388 if not callable (func ):
278389 raise VectorStoreOperationException ("Compiled filter is not callable." )
390+
279391 return func
280392
281393 def _run_filter (self , filter : Callable , record : AttributeDict [TAKey , TAValue ]) -> bool :
0 commit comments