Skip to content

Commit 7c5e248

Browse files
committed
feat: add --timeout/-t flag for initialization timeout
Adds a configurable timeout for MCP server initialization: - Default remains 10 seconds (matching current behavior) - Can be set via --timeout or -t flag - Includes tests for timeout flag parsing
1 parent e7bd724 commit 7c5e248

File tree

4 files changed

+86
-11
lines changed

4 files changed

+86
-11
lines changed

cmd/mcptools/commands/call.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@ func parseCallArgs(cmdArgs []string) (string, []string) {
3636
case (cmdArgs[i] == FlagAuthHeader) && i+1 < len(cmdArgs):
3737
AuthHeader = cmdArgs[i+1]
3838
i += 2
39+
case (cmdArgs[i] == FlagTimeout || cmdArgs[i] == FlagTimeoutShort) && i+1 < len(cmdArgs):
40+
if _, err := fmt.Sscanf(cmdArgs[i+1], "%d", &InitTimeout); err != nil {
41+
fmt.Fprintf(os.Stderr, "Warning: invalid timeout value %q, using default\n", cmdArgs[i+1])
42+
}
43+
i += 2
3944
case !entityExtracted:
4045
entityName = cmdArgs[i]
4146
entityExtracted = true

cmd/mcptools/commands/root.go

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@ import (
99

1010
// flags.
1111
const (
12-
FlagFormat = "--format"
13-
FlagFormatShort = "-f"
14-
FlagParams = "--params"
15-
FlagParamsShort = "-p"
16-
FlagHelp = "--help"
17-
FlagHelpShort = "-h"
18-
FlagServerLogs = "--server-logs"
19-
FlagTransport = "--transport"
20-
FlagAuthUser = "--auth-user"
21-
FlagAuthHeader = "--auth-header"
12+
FlagFormat = "--format"
13+
FlagFormatShort = "-f"
14+
FlagParams = "--params"
15+
FlagParamsShort = "-p"
16+
FlagHelp = "--help"
17+
FlagHelpShort = "-h"
18+
FlagServerLogs = "--server-logs"
19+
FlagTransport = "--transport"
20+
FlagAuthUser = "--auth-user"
21+
FlagAuthHeader = "--auth-header"
22+
FlagTimeout = "--timeout"
23+
FlagTimeoutShort = "-t"
2224
)
2325

2426
// entity types.
@@ -51,6 +53,8 @@ var (
5153
AuthUser string
5254
// AuthHeader is a custom Authorization header.
5355
AuthHeader string
56+
// InitTimeout is the timeout for MCP server initialization in seconds.
57+
InitTimeout = 10
5458
)
5559

5660
// RootCmd creates the root command.
@@ -68,6 +72,7 @@ It allows you to discover and call tools, list resources, and interact with MCP-
6872
cmd.PersistentFlags().StringVar(&TransportOption, "transport", "http", "HTTP transport type (http, sse)")
6973
cmd.PersistentFlags().StringVar(&AuthUser, "auth-user", "", "Basic authentication in username:password format")
7074
cmd.PersistentFlags().StringVar(&AuthHeader, "auth-header", "", "Custom Authorization header (e.g., 'Bearer token' or 'Basic base64credentials')")
75+
cmd.PersistentFlags().IntVarP(&InitTimeout, "timeout", "t", 10, "Initialization timeout in seconds")
7176

7277
return cmd
7378
}

cmd/mcptools/commands/utils.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"encoding/json"
88
"fmt"
99
"net/url"
10+
"os"
1011
"strings"
1112
"time"
1213

@@ -176,7 +177,7 @@ var CreateClientFunc = func(args []string, _ ...client.ClientOption) (*client.Cl
176177
if err != nil {
177178
return nil, fmt.Errorf("init error: %w", err)
178179
}
179-
case <-time.After(10 * time.Second):
180+
case <-time.After(time.Duration(InitTimeout) * time.Second):
180181
return nil, fmt.Errorf("initialization timed out")
181182
}
182183

@@ -211,6 +212,11 @@ func ProcessFlags(args []string) []string {
211212
case args[i] == FlagAuthHeader && i+1 < len(args):
212213
AuthHeader = args[i+1]
213214
i += 2
215+
case (args[i] == FlagTimeout || args[i] == FlagTimeoutShort) && i+1 < len(args):
216+
if _, err := fmt.Sscanf(args[i+1], "%d", &InitTimeout); err != nil {
217+
fmt.Fprintf(os.Stderr, "Warning: invalid timeout value %q, using default\n", args[i+1])
218+
}
219+
i += 2
214220
default:
215221
parsedArgs = append(parsedArgs, args[i])
216222
i++

cmd/mcptools/commands/utils_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,62 @@ nested {"key":"value"}`[1:] // remove first newline
190190
})
191191
}
192192
}
193+
194+
func TestProcessFlagsTimeout(t *testing.T) {
195+
originalTimeout := InitTimeout
196+
defer func() { InitTimeout = originalTimeout }()
197+
198+
tests := []struct {
199+
name string
200+
args []string
201+
wantArgs []string
202+
wantTimeout int
203+
}{
204+
{
205+
name: "default timeout",
206+
args: []string{"cmd", "arg1"},
207+
wantArgs: []string{"cmd", "arg1"},
208+
wantTimeout: 10,
209+
},
210+
{
211+
name: "long timeout flag",
212+
args: []string{"cmd", "--timeout", "60", "arg1"},
213+
wantArgs: []string{"cmd", "arg1"},
214+
wantTimeout: 60,
215+
},
216+
{
217+
name: "short timeout flag",
218+
args: []string{"cmd", "-t", "10", "arg1"},
219+
wantArgs: []string{"cmd", "arg1"},
220+
wantTimeout: 10,
221+
},
222+
{
223+
name: "timeout at end",
224+
args: []string{"cmd", "arg1", "--timeout", "120"},
225+
wantArgs: []string{"cmd", "arg1"},
226+
wantTimeout: 120,
227+
},
228+
{
229+
name: "invalid timeout keeps previous",
230+
args: []string{"cmd", "--timeout", "invalid", "arg1"},
231+
wantArgs: []string{"cmd", "arg1"},
232+
wantTimeout: 10,
233+
},
234+
}
235+
236+
for _, tt := range tests {
237+
t.Run(tt.name, func(t *testing.T) {
238+
InitTimeout = 10
239+
240+
gotArgs := ProcessFlags(tt.args)
241+
242+
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
243+
t.Errorf("ProcessFlags() gotArgs = %v, want %v", gotArgs, tt.wantArgs)
244+
}
245+
246+
if InitTimeout != tt.wantTimeout {
247+
t.Errorf("ProcessFlags() InitTimeout = %v, want %v", InitTimeout, tt.wantTimeout)
248+
}
249+
})
250+
}
251+
}

0 commit comments

Comments
 (0)