Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -583,9 +583,15 @@ private IReadOnlyList<ValueExpression> GetProtocolMethodArguments(Dictionary<str
bodyModel = ScmCodeModelGenerator.Instance.TypeFactory.CreateModel(model);
}

var bodyParamProvider = ConvenienceMethodParameters.FirstOrDefault(p => p.Location == ParameterLocation.Body);
var nonBodyProperties = bodyModel?.CanonicalView.Properties
.Where(p => p.WireInfo?.IsHttpMetadata == true)
.ToDictionary(p => p.WireInfo!.SerializedName, p => p);

// Create a mapping from convenience parameter names to their ParameterProvider
var convenienceParamsMap = ConvenienceMethodParameters.ToDictionary(p => p.Name, p => p, StringComparer.OrdinalIgnoreCase);

bool requireNamedArgs = false;
// Iterate through protocol parameters to maintain correct argument order
foreach (var protocolParam in ProtocolMethodParameters)
{
Expand All @@ -596,7 +602,7 @@ private IReadOnlyList<ValueExpression> GetProtocolMethodArguments(Dictionary<str
}

// Try to find the corresponding convenience parameter using MethodParameterSegments
if (protocolParam.InputParameter?.MethodParameterSegments is { Count: > 1 })
if (protocolParam.InputParameter?.MethodParameterSegments is { Count: > 1 } && nonBodyProperties?.ContainsKey(protocolParam.Name) != true)
{
// The MethodParameterSegments represents a path (e.g., ['Params', 'foo'] means params.foo)
var rootParameterName = protocolParam.InputParameter.MethodParameterSegments[0].Name;
Expand All @@ -616,7 +622,7 @@ private IReadOnlyList<ValueExpression> GetProtocolMethodArguments(Dictionary<str
if (ScmCodeModelGenerator.Instance.TypeFactory.CSharpTypeMap.TryGetValue(convenienceParam.Type, out var typeProvider) &&
typeProvider is ModelProvider paramModel)
{
conversions.Add(paramModel.GetPropertyExpression(convenienceParam, propertySegments));
AddArgument(protocolParam, paramModel.GetPropertyExpression(convenienceParam, propertySegments));
}
}
else
Expand All @@ -625,12 +631,20 @@ private IReadOnlyList<ValueExpression> GetProtocolMethodArguments(Dictionary<str
{
if (protocolParam.IsContentParameter)
{
convenienceParam = ConvenienceMethodParameters.FirstOrDefault(p => p.Location == ParameterLocation.Body);
convenienceParam = bodyParamProvider;
}
}

if (convenienceParam == null)
{
if (TryGetNonBodyModelPropertyConversion(protocolParam, out var conversion))
{
AddArgument(protocolParam, conversion);
}
else
{
requireNamedArgs = true;
}
continue;
}

Expand All @@ -639,68 +653,41 @@ private IReadOnlyList<ValueExpression> GetProtocolMethodArguments(Dictionary<str
{
if (!addedSpreadSource && declarations.TryGetValue("spread", out ValueExpression? spread))
{
conversions.Add(spread);
AddArgument(protocolParam, spread);
addedSpreadSource = true;
}
}
else if (convenienceParam.Location == ParameterLocation.Body)
{
// Add any non-body parameters that may have been declared within the request body model
List<ValueExpression>? requiredParameters = null;
List<ValueExpression>? optionalParameters = null;

if (convenienceParam.Type.Equals(bodyModel?.Type))
{
var parameterConversions =
GetNonBodyModelPropertiesConversions(convenienceParam, bodyModel);
if (parameterConversions != null)
{
requiredParameters = parameterConversions.Value.RequiredParameters;
optionalParameters = parameterConversions.Value.OptionalParameters;
}
}

// Add required non-body parameters
if (requiredParameters != null)
{
conversions.AddRange(requiredParameters);
}

if (convenienceParam.Type.IsReadOnlyMemory || convenienceParam.Type.IsList)
{
conversions.Add(declarations["content"]);
AddArgument(protocolParam, declarations["content"]);
}
else if (convenienceParam.Type.IsEnum)
{
conversions.Add(RequestContentApiSnippets.Create(
AddArgument(protocolParam, RequestContentApiSnippets.Create(
BinaryDataSnippets.FromObjectAsJson(convenienceParam.Type.ToSerial(convenienceParam))));
}
else if (convenienceParam.Type.Equals(typeof(BinaryData)))
{
conversions.Add(RequestContentApiSnippets.Create(convenienceParam));
AddArgument(protocolParam, RequestContentApiSnippets.Create(convenienceParam));
}
else if (convenienceParam.Type.IsFrameworkType)
{
conversions.Add(declarations["content"]);
AddArgument(protocolParam, declarations["content"]);
}
else
{
conversions.Add(convenienceParam);
}

// Add optional non-body parameters
if (optionalParameters != null)
{
conversions.AddRange(optionalParameters);
AddArgument(protocolParam, convenienceParam);
}
}
else if (convenienceParam.Type.IsEnum)
{
conversions.Add(convenienceParam.Type.ToSerial(convenienceParam));
AddArgument(protocolParam, convenienceParam.Type.ToSerial(convenienceParam));
}
else
{
conversions.Add(convenienceParam);
AddArgument(protocolParam, convenienceParam);
}
}
}
Expand All @@ -709,45 +696,37 @@ private IReadOnlyList<ValueExpression> GetProtocolMethodArguments(Dictionary<str
var requestOptionsApi = ScmCodeModelGenerator.Instance.TypeFactory.HttpRequestOptionsApi;
// Build method name like "ToRequestOptions" or "ToRequestContext" based on the parameter name
var toRequestOptionsMethodName = $"ToRequest{char.ToUpper(requestOptionsApi.ParameterName[0])}{requestOptionsApi.ParameterName.Substring(1)}";
conversions.Add(ScmKnownParameters.CancellationToken.Invoke(toRequestOptionsMethodName, extensionType: _cancellationTokenExtensionsDefinition.Type));
AddArgument(ScmKnownParameters.RequestOptions, ScmKnownParameters.CancellationToken.Invoke(toRequestOptionsMethodName, extensionType: _cancellationTokenExtensionsDefinition.Type));

return conversions;
}

private (List<ValueExpression> RequiredParameters, List<ValueExpression> OptionalParameters)?
GetNonBodyModelPropertiesConversions(ParameterProvider bodyParam, ModelProvider bodyModel)
{
// Extract non-body properties from the body model
var nonBodyProperties = bodyModel.CanonicalView.Properties
.Where(p => p.WireInfo?.IsHttpMetadata == true)
.ToDictionary(p => p.WireInfo!.SerializedName, p => p);
void AddArgument(ParameterProvider protocolParam, ValueExpression argument)
{
conversions.Add(requireNamedArgs ? protocolParam.PositionalReference(argument) : argument);
}

if (nonBodyProperties.Count == 0)
return null;
bool TryGetNonBodyModelPropertyConversion(ParameterProvider protocolParam, out ValueExpression conversion)
{
conversion = Default;
if (bodyParamProvider is null || bodyModel is null || nonBodyProperties is null)
{
return false;
}

List<ValueExpression> required = [];
List<ValueExpression> optional = [];
if (!bodyParamProvider.Type.Equals(bodyModel.Type) || protocolParam.Location == ParameterLocation.Body)
{
return false;
}

// Add properties for matching protocol parameters
foreach (var protocolParameter in ProtocolMethodParameters)
{
if (protocolParameter.Location != ParameterLocation.Body &&
(nonBodyProperties.TryGetValue(protocolParameter.WireInfo.SerializedName, out var nonBodyProperty) ||
nonBodyProperties.TryGetValue(protocolParameter.Name, out nonBodyProperty)))
if (nonBodyProperties.TryGetValue(protocolParam.WireInfo.SerializedName, out var nonBodyProperty) ||
nonBodyProperties.TryGetValue(protocolParam.Name, out nonBodyProperty))
{
var conversion = bodyParam.Property(nonBodyProperty.Name);
if (protocolParameter.DefaultValue != null)
{
optional.Add(conversion);
}
else
{
required.Add(conversion);
}
conversion = bodyParamProvider.Property(nonBodyProperty.Name);
return true;
}

return false;
}

return (required, optional);
return conversions;
}

private ScmMethodProvider BuildProtocolMethod(MethodProvider createRequestMethod, bool isAsync, bool shouldMakeParametersRequired)
Expand Down
Loading