Skip to content

Commit 569c8ec

Browse files
committed
refactor: oidc provider configuration
Fixes the issue where discovery urls where assembled incorrectly. Also removes the default values for unbranded providers, since there is no sane default that can be provided.
1 parent 5110b9e commit 569c8ec

File tree

1 file changed

+171
-104
lines changed

1 file changed

+171
-104
lines changed

src/service/oauth/providers.rs

Lines changed: 171 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::collections::BTreeMap;
33
use serde_json::{Map as JsonObject, Value as JsonValue};
44
use tokio::sync::RwLock;
55
pub use tuwunel_core::config::IdentityProvider as Provider;
6-
use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, err, implement};
6+
use tuwunel_core::{Err, Result, debug, debug::INFO_SPAN_LEVEL, implement};
77
use url::Url;
88

99
use crate::SelfServices;
@@ -123,103 +123,154 @@ async fn configure(&self, mut provider: Provider) -> Result<Provider> {
123123
.name
124124
.get_or_insert_with(|| provider.brand.clone());
125125

126+
if provider.brand == "github" {
127+
return configure_github(provider);
128+
}
129+
126130
if provider.issuer_url.is_none() {
127-
_ = provider
128-
.issuer_url
129-
.replace(match provider.brand.as_str() {
130-
| "github" => "https://github.com".try_into()?,
131-
| "gitlab" => "https://gitlab.com".try_into()?,
132-
| "google" => "https://accounts.google.com".try_into()?,
133-
| _ => return Err!(Config("issuer_url", "Required for this provider.")),
134-
});
131+
provider.issuer_url = Some(match provider.brand.as_str() {
132+
| "gitlab" => "https://gitlab.com".try_into()?,
133+
| "google" => "https://accounts.google.com".try_into()?,
134+
| _ => return Err!(Config("issuer_url", "Required for this provider.")),
135+
});
135136
}
136137

137-
if provider.base_path.is_none() {
138-
provider.base_path = match provider.brand.as_str() {
139-
| "github" => Some("/login/oauth".to_owned()),
140-
| _ => None,
141-
};
138+
if !provider.discovery {
139+
assert_manual_urls(&provider)?;
140+
return Ok(provider);
142141
}
143142

144-
let response = self
145-
.discover(&provider)
146-
.await
147-
.and_then(|response| {
148-
response.as_object().cloned().ok_or_else(|| {
149-
err!(Request(NotJson("Expecting JSON object for discovery response")))
150-
})
151-
})
152-
.and_then(|response| check_issuer(response, &provider))?;
143+
let discovery_response = {
144+
let response = self.discover(&provider).await?;
145+
let Some(response_map) = response.as_object() else {
146+
return Err!(Request(NotJson("Expecting JSON object for discovery response")));
147+
};
148+
149+
check_issuer(response_map, &provider)?;
150+
151+
response_map.to_owned()
152+
};
153153

154154
if provider.authorization_url.is_none() {
155-
response
156-
.get("authorization_endpoint")
157-
.and_then(JsonValue::as_str)
158-
.map(Url::parse)
159-
.transpose()?
160-
.or_else(|| make_url(&provider, "/authorize").ok())
161-
.map(|url| provider.authorization_url.replace(url));
155+
provider.authorization_url = Some(assert_and_parse_url(
156+
provider.id(),
157+
&discovery_response,
158+
"authorization_endpoint",
159+
)?);
162160
}
163161

164162
if provider.revocation_url.is_none() {
165-
response
166-
.get("revocation_endpoint")
167-
.and_then(JsonValue::as_str)
168-
.map(Url::parse)
169-
.transpose()?
170-
.or_else(|| make_url(&provider, "/revocation").ok())
171-
.map(|url| provider.revocation_url.replace(url));
163+
provider.revocation_url = Some(assert_and_parse_url(
164+
provider.id(),
165+
&discovery_response,
166+
"revocation_endpoint",
167+
)?);
172168
}
173169

174170
if provider.introspection_url.is_none() {
175-
response
176-
.get("introspection_endpoint")
177-
.and_then(JsonValue::as_str)
178-
.map(Url::parse)
179-
.transpose()?
180-
.or_else(|| make_url(&provider, "/introspection").ok())
181-
.map(|url| provider.introspection_url.replace(url));
171+
provider.introspection_url = Some(assert_and_parse_url(
172+
provider.id(),
173+
&discovery_response,
174+
"introspection_endpoint",
175+
)?);
182176
}
183177

184178
if provider.userinfo_url.is_none() {
185-
response
186-
.get("userinfo_endpoint")
187-
.and_then(JsonValue::as_str)
188-
.map(Url::parse)
189-
.transpose()?
190-
.or_else(|| match provider.brand.as_str() {
191-
| "github" => "https://api.github.com/user".try_into().ok(),
192-
| _ => make_url(&provider, "/userinfo").ok(),
193-
})
194-
.map(|url| provider.userinfo_url.replace(url));
179+
provider.userinfo_url =
180+
Some(assert_and_parse_url(provider.id(), &discovery_response, "userinfo_endpoint")?);
195181
}
196182

197183
if provider.token_url.is_none() {
198-
response
199-
.get("token_endpoint")
200-
.and_then(JsonValue::as_str)
201-
.map(Url::parse)
202-
.transpose()?
203-
.or_else(|| {
204-
let path = if provider.brand == "github" {
205-
"/access_token"
206-
} else {
207-
"/token"
208-
};
209-
210-
make_url(&provider, path).ok()
211-
})
212-
.map(|url| provider.token_url.replace(url));
184+
provider.token_url =
185+
Some(assert_and_parse_url(provider.id(), &discovery_response, "token_endpoint")?);
213186
}
214187

215188
Ok(provider)
216189
}
217190

191+
fn configure_github(mut provider: Provider) -> Result<Provider> {
192+
// See: https://logto.io/oauth-providers-explorer/github
193+
// TODO: find better (first-party?) documentation for these endpoints
194+
195+
provider.base_path = Some("/login/oauth".to_owned());
196+
provider.authorization_url = Some(
197+
"https://api.github.com/authorize"
198+
.try_into()
199+
.expect("valid url"),
200+
);
201+
provider.revocation_url = Some(
202+
"https://api.github.com/revocation"
203+
.try_into()
204+
.expect("valid url"),
205+
);
206+
provider.introspection_url = Some(
207+
"https://api.github.com/introspection"
208+
.try_into()
209+
.expect("valid url"),
210+
);
211+
provider.userinfo_url = Some(
212+
"https://api.github.com/user"
213+
.try_into()
214+
.expect("valid url"),
215+
);
216+
provider.token_url = Some(
217+
"https://api.github.com/access_token"
218+
.try_into()
219+
.expect("valid url"),
220+
);
221+
222+
Ok(provider)
223+
}
224+
225+
fn assert_manual_urls(provider: &Provider) -> Result<()> {
226+
if provider.authorization_url.is_none() {
227+
return Err!(Config(
228+
"authorization_url",
229+
"Required for provider {}, since discovery is disabled",
230+
provider.client_id
231+
));
232+
}
233+
234+
if provider.revocation_url.is_none() {
235+
return Err!(Config(
236+
"revocation_url",
237+
"Required for provider {}, since discovery is disabled",
238+
provider.client_id
239+
));
240+
}
241+
242+
if provider.introspection_url.is_none() {
243+
return Err!(Config(
244+
"introspection_url",
245+
"Required for provider {}, since discovery is disabled",
246+
provider.client_id
247+
));
248+
}
249+
250+
if provider.userinfo_url.is_none() {
251+
return Err!(Config(
252+
"userinfo_url",
253+
"Required for provider {}, since discovery is disabled",
254+
provider.client_id
255+
));
256+
}
257+
258+
if provider.token_url.is_none() {
259+
return Err!(Config(
260+
"token_url",
261+
"Required for provider {}, since discovery is disabled",
262+
provider.client_id
263+
));
264+
}
265+
266+
Ok(())
267+
}
268+
218269
/// Send a network request to a provider at the computed location of the
219270
/// `.well-known/openid-configuration`, returning the configuration.
220271
#[implement(Providers)]
221272
#[tracing::instrument(level = "debug", ret(level = "trace"), skip(self))]
222-
pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
273+
async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
223274
self.services
224275
.client
225276
.oauth
@@ -235,33 +286,51 @@ pub async fn discover(&self, provider: &Provider) -> Result<JsonValue> {
235286
/// Compute the location of the `/.well-known/openid-configuration` based on the
236287
/// local provider config.
237288
fn discovery_url(provider: &Provider) -> Result<Url> {
238-
let default_url = provider
239-
.discovery
240-
.then(|| make_url(provider, "/.well-known/openid-configuration"))
241-
.transpose()?;
242-
243-
let Some(url) = provider
244-
.discovery_url
245-
.clone()
246-
.filter(|_| provider.discovery)
247-
.or(default_url)
248-
else {
289+
if let Some(url) = &provider.discovery_url {
290+
return Ok(url.to_owned());
291+
}
292+
293+
let issuer = provider
294+
.issuer_url
295+
.as_ref()
296+
.expect("issuer to be asserted before calling discover");
297+
298+
let issuer_path = issuer.path();
299+
300+
let ressource_path = ".well-known/openid-configuration";
301+
302+
let base_url = if issuer_path.ends_with('/') {
303+
issuer.to_owned()
304+
} else {
305+
let mut url = issuer.to_owned();
306+
url.set_path((issuer_path.to_owned() + "/").as_str());
307+
url
308+
};
309+
310+
let Some(base_path) = provider.base_path.as_ref() else {
311+
return Ok(base_url.join(ressource_path)?);
312+
};
313+
314+
if base_path.is_empty() {
249315
return Err!(Config(
250-
"discovery_url",
251-
"Failed to determine URL for discovery of provider {}",
316+
"base_path",
317+
"Provider '{}' has an empty base_path. Remove the key or add a value",
252318
provider.id()
253319
));
320+
}
321+
322+
let base_path = if base_path.ends_with('/') {
323+
base_path.to_owned()
324+
} else {
325+
base_path.to_owned() + "/"
254326
};
255327

256-
Ok(url)
328+
Ok(base_url.join((base_path + ressource_path).as_ref())?)
257329
}
258330

259331
/// Validate that the locally configured `issuer_url` matches the issuer claimed
260332
/// in any response. todo: cryptographic validation is not yet implemented here.
261-
fn check_issuer(
262-
response: JsonObject<String, JsonValue>,
263-
provider: &Provider,
264-
) -> Result<JsonObject<String, JsonValue>> {
333+
fn check_issuer(response: &JsonObject<String, JsonValue>, provider: &Provider) -> Result<()> {
265334
let expected = provider
266335
.issuer_url
267336
.as_ref()
@@ -279,23 +348,21 @@ fn check_issuer(
279348
)));
280349
}
281350

282-
Ok(response)
351+
Ok(())
283352
}
284353

285-
/// Generate a full URL for a request to the idp based on the idp's derived
286-
/// configuration.
287-
fn make_url(provider: &Provider, path: &str) -> Result<Url> {
288-
let mut suffix = provider.base_path.clone().unwrap_or_default();
289-
290-
suffix.push_str(path);
291-
let url = provider
292-
.issuer_url
293-
.as_ref()
294-
.ok_or_else(|| {
295-
let id = &provider.client_id;
296-
err!(Config("issuer_url", "Provider {id:?} required field"))
297-
})?
298-
.join(&suffix)?;
354+
/// Assert that a url exists in the response and parse it
355+
fn assert_and_parse_url(
356+
provider_id: &str,
357+
response_map: &serde_json::Map<String, serde_json::Value>,
358+
url_name: &str,
359+
) -> Result<Url> {
360+
let Some(url_value) = response_map.get(url_name) else {
361+
return Err!(
362+
"Error building oidc provider '{provider_id}': {url_name} is missing from openid \
363+
discovery response",
364+
);
365+
};
299366

300-
Ok(url)
367+
Ok(url_value.as_str().unwrap_or_default().parse()?)
301368
}

0 commit comments

Comments
 (0)