diff --git a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlNullabilityProcessor.cs b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlNullabilityProcessor.cs index ce62d6755b1..a44d8389410 100644 --- a/src/EFCore.SqlServer/Query/Internal/SqlServerSqlNullabilityProcessor.cs +++ b/src/EFCore.SqlServer/Query/Internal/SqlServerSqlNullabilityProcessor.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. using System.Collections; @@ -32,11 +32,11 @@ public class SqlServerSqlNullabilityProcessor( public const string OpenJsonParameterTableName = "__openjson"; private readonly ISqlServerSingletonOptions _sqlServerSingletonOptions = sqlServerSingletonOptions; + private readonly ISqlExpressionFactory _sqlExpressionFactory = dependencies.SqlExpressionFactory; private int _openJsonAliasCounter; private int _totalParameterCount; - /// /// This is an internal API that supports the Entity Framework Core infrastructure and not subject to /// the same compatibility standards as public APIs. It may be changed or removed without notice in @@ -258,6 +258,7 @@ protected override SqlExpression VisitIn(InExpression inExpression, bool allowOp { Check.DebugAssert(valuesParameter.TypeMapping is not null); Check.DebugAssert(valuesParameter.TypeMapping.ElementTypeMapping is not null); + var elementTypeMapping = (RelationalTypeMapping)valuesParameter.TypeMapping.ElementTypeMapping; if (TryHandleOverLimitParameters( @@ -268,30 +269,42 @@ protected override SqlExpression VisitIn(InExpression inExpression, bool allowOp out var constants, out var containsNulls)) { - inExpression = (openJson, constants) switch + if (openJson != null) + { + var column = new ColumnExpression( + "value", + openJson.Alias, + valuesParameter.Type.GetSequenceType().UnwrapNullableType(), + elementTypeMapping, + containsNulls!.Value); + + var subquery = SelectExpression.CreateImmutable( + null!, + [openJson], + [new ProjectionExpression(column, "value")], + null!); + + nullable = false; + + var translatedIn = inExpression.Update(inExpression.Item, subquery); + + if (containsNulls.GetValueOrDefault()) + { + return _sqlExpressionFactory.OrElse( + translatedIn, + _sqlExpressionFactory.IsNull(inExpression.Item)); + } + + return translatedIn; + } + + if (constants != null) { - (not null, null) - => inExpression.Update( - inExpression.Item, - SelectExpression.CreateImmutable( - null!, - [openJson], - [ - new ProjectionExpression( - new ColumnExpression( - "value", - openJson.Alias, - valuesParameter.Type.GetSequenceType(), - elementTypeMapping, - containsNulls!.Value), - "value") - ], - null!)), - - (null, not null) => inExpression.Update(inExpression.Item, constants), - - _ => throw new UnreachableException(), - }; + nullable = false; + return inExpression.Update(inExpression.Item, constants); + } + + throw new UnreachableException("TryHandleOverLimitParameters should return either openJson or constants."); } return base.VisitIn(inExpression, allowOptimizedExpansion, out nullable); diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs index 8824c821189..e6b76455bf1 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/PrimitiveCollectionsQuerySqlServerTest.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. namespace Microsoft.EntityFrameworkCore.Query; @@ -2568,6 +2568,26 @@ SELECT COUNT(*) """); } + [ConditionalFact] + public virtual async Task Parameter_collection_with_null_value_Contains_null_element_Real_Check() + { + using var context = Fixture.CreateContext(); + + var values = Enumerable.Range(1, 2200).Select(i => (int?)i).ToList(); + values.Add(null); + + var queryResults = await context.Set() + .Where(e => values.Contains(e.NullableInt)) + .ToListAsync(); + + var sql = Fixture.TestSqlLoggerFactory.SqlStatements.Last(); + + Assert.NotEmpty(queryResults); + Assert.Contains(queryResults, e => e.NullableInt == null); + Assert.Contains("OPENJSON", sql); + Assert.Contains("IS NULL", sql); + } + [ConditionalFact] public virtual async Task Parameter_collection_of_ints_Contains_int_2071_values() {