Skip to content

Commit 723b303

Browse files
committed
C#: Support generic extensions in non-generic extension types.
1 parent 29c3c63 commit 723b303

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

csharp/extractor/Semmle.Extraction.CSharp/CodeAnalysisExtensions/SymbolExtensions.cs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -653,8 +653,7 @@ public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodS
653653
.FirstOrDefault(m => SymbolEqualityComparer.Default.Equals(m.AssociatedExtensionImplementation, method.ConstructedFrom));
654654

655655
var isFullyConstructed = method.IsBoundGenericMethod();
656-
// TODO: We also need to handle generic methods in non-generic extension types.
657-
if (isFullyConstructed && unboundDeclaration?.ContainingType is INamedTypeSymbol extensionType && extensionType.IsGenericType)
656+
if (isFullyConstructed && unboundDeclaration?.ContainingType is INamedTypeSymbol extensionType)
658657
{
659658
try
660659
{
@@ -663,7 +662,9 @@ public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodS
663662
var (extensionTypeArguments, extensionMethodArguments) = arguments.SplitAt(extensionType.TypeParameters.Length);
664663

665664
// Construct the extension type.
666-
var boundExtensionType = extensionType.Construct(extensionTypeArguments.ToArray());
665+
var boundExtensionType = extensionType.IsUnboundGenericType()
666+
? extensionType.Construct(extensionTypeArguments.ToArray())
667+
: extensionType;
667668

668669
// Find the extension method declaration within the constructed extension type.
669670
var extensionDeclaration = boundExtensionType.GetMembers()
@@ -690,10 +691,22 @@ public static bool TryGetExtensionMethod(this IMethodSymbol method, out IMethodS
690691
return declaration is not null;
691692
}
692693

694+
/// <summary>
695+
/// Returns true if this method is an unbound generic method.
696+
/// </summary>
693697
public static bool IsUnboundGenericMethod(this IMethodSymbol method) =>
694698
method.IsGenericMethod && SymbolEqualityComparer.Default.Equals(method.ConstructedFrom, method);
695699

696-
public static bool IsBoundGenericMethod(this IMethodSymbol method) => method.IsGenericMethod && !IsUnboundGenericMethod(method);
700+
/// <summary>
701+
/// Returns true if this method is a bound generic method.
702+
/// </summary>
703+
public static bool IsBoundGenericMethod(this IMethodSymbol method) => method.IsGenericMethod && !method.IsUnboundGenericMethod();
704+
705+
/// <summary>
706+
/// Returns true if this type is an unbound generic type.
707+
/// </summary>
708+
public static bool IsUnboundGenericType(this INamedTypeSymbol type) =>
709+
type.IsGenericType && SymbolEqualityComparer.Default.Equals(type.ConstructedFrom, type);
697710

698711
/// <summary>
699712
/// Gets the base type of `symbol`. Unlike `symbol.BaseType`, this excludes effective base

0 commit comments

Comments
 (0)