Skip to content
Snippets Groups Projects
Forked from an inaccessible project.
oauth.rs 5.57 KiB
/* Copyright 2021 Dominik George <dominik.george@teckids.org>
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

use crate::config::{
    get_or_error,
    get_optional
};

use config::Config;

use oauth2::{
    AuthUrl,
    ClientId,
    ClientSecret,
    RequestTokenError,
    ResourceOwnerUsername,
    ResourceOwnerPassword,
    Scope,
    TokenResponse,
    TokenUrl
};
use oauth2::basic::{
    BasicClient,
    BasicTokenResponse
};
use oauth2::reqwest::http_client;

use std::error;

use serde::Deserialize;
use reqwest;

use serde_json;
use jq_rs;

/// Builds a dotted configuration key such as `prefix.client_id` by joining `parts` with `.`.
fn full_key(parts: Vec<&str>) -> String {
    let mut key = String::new();
    for (index, part) in parts.iter().enumerate() {
        if index > 0 {
            key.push('.');
        }
        key.push_str(part);
    }
    key
}

fn get_client<E: Copy>(conf: &Config, prefix: &str, error_value: E) -> Result<BasicClient, E> {
    let client_id = ClientId::new(get_or_error(&conf, &full_key(vec![prefix, "client_id"]), error_value)?);
    let client_secret = match get_optional(&conf, &full_key(vec![prefix, "client_secret"])) {
        Some(v) => Some(ClientSecret::new(v)),
        None => None,
    };
    let auth_url = match AuthUrl::new(get_or_error(&conf, &full_key(vec![prefix, "auth_url"]), error_value)?) {
        Ok(u) => u,
        _ => {
            error!("Could not parse authorization URL");
            return Err(error_value);
        },
    };
    let token_url = match get_optional(&conf, &full_key(vec![prefix, "token_url"])) {
        Some(v) => match TokenUrl::new(v) {
            Ok(u) => Some(u),
            Err(_) => {
                error!("Could not parse token URL");
                return Err(error_value);
            }
        },
        None => None,
    };

    let client = BasicClient::new(client_id, client_secret, auth_url, token_url);
    return Ok(client);
}

pub fn get_access_token_client<E: Copy>(conf: &Config, prefix: &str, error_value: E, unauth_value: E) -> Result<BasicTokenResponse, E> {
    let scopes: Vec<String> = match get_optional(&conf, &full_key(vec![prefix, "scopes"])) {
        Some(v) => v,
        None => vec![]
    };

    let client = get_client(conf, prefix, error_value)?;
    let mut request = client.exchange_client_credentials();
    for scope in scopes {
        request = request.add_scope(Scope::new(scope.to_string()));
    }

    let result = request.request(http_client);
    match result {
            Ok(t) => Ok(t),
            Err(e) => match e {
                RequestTokenError::ServerResponse(t) => {
                    error!("Authorization server returned error: {}", t);
                    return Err(unauth_value);
                },
                _ => {
                    error!("Error fetchin access token: {}", e);
                    return Err(error_value);
                },
            },
        }
}

pub fn get_access_token_password<E: Copy>(conf: &Config, prefix: &str, username: String, password: String, error_value: E, unauth_value: E) -> Result<BasicTokenResponse, E> {
    let scopes: Vec<String> = match get_optional(&conf, &full_key(vec![prefix, "scopes"])) {
        Some(v) => v,
        None => vec![]
    };

    let res_username = ResourceOwnerUsername::new(username);
    let res_password = ResourceOwnerPassword::new(password);

    let client = get_client(conf, prefix, error_value)?;
    let mut request = client.exchange_password(&res_username, &res_password);
    for scope in scopes {
        request = request.add_scope(Scope::new(scope.to_string()));
    }

    let result = request.request(http_client);
    match result {
            Ok(t) => Ok(t),
            Err(e) => match e {
                RequestTokenError::ServerResponse(t) => {
                    error!("Authorization server returned error: {}", t);
                    return Err(unauth_value);
                },
                _ => {
                    error!("Error fetchin access token: {}", e);
                    return Err(error_value);
                },
            },
        }
}

/// Fetches the text body of the endpoint configured at `<prefix>.urls.<endpoint>`.
///
/// Every `{}` in the configured URL template is replaced with `param`. The GET request is
/// authenticated with the access token from `token`, sent as a bearer token.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the URL cannot be read from the configuration (the message names the key);
/// * the request fails;
/// * the server answers with a 4xx/5xx status;
/// * the body cannot be read as text.
fn get_data(conf: &Config, prefix: &str, endpoint: &str, param: String, token: &BasicTokenResponse) -> Result<String, Box<dyn error::Error>> {
    let access_token = token.access_token().secret();

    let url_key = full_key(vec![prefix, "urls", endpoint]);
    let endpoint_url: String = get_or_error(conf, &url_key, "")
        .map_err(|_| format!("Could not read endpoint URL from configuration key {}", url_key))?;
    // NOTE(review): `param` is substituted verbatim (no percent-encoding); confirm callers never
    // pass values containing URL metacharacters such as `/`, `?` or `#`.
    let endpoint_url = endpoint_url.replace("{}", &param);

    debug!("Loading text data from {}", endpoint_url);
    let client = reqwest::blocking::Client::new();
    let body = client
        .get(&endpoint_url)
        .bearer_auth(access_token)
        .send()?
        // Turn 4xx/5xx into an error instead of returning the error page as data.
        .error_for_status()?
        .text()?;
    Ok(body)
}

/// Fetches `endpoint` (via `get_data`), transforms the body with jq and deserializes the result.
///
/// The jq filter is read from `<prefix>.maps.<endpoint>`. When `multi` is set, the filter is
/// wrapped as `map(<filter>)` (jq's `map`, applied to each element of the input array). Without
/// a configured filter, the identity program `.` is used.
///
/// # Errors
///
/// Returns an error if the jq program fails to compile, the data cannot be fetched, jq fails
/// on the data, or the jq output cannot be deserialized into `T`.
pub fn get_data_jq<T: for<'de> Deserialize<'de>>(conf: &Config, prefix: &str, endpoint: &str, param: String, token: &BasicTokenResponse, multi: bool) -> Result<T, Box<dyn error::Error>> {
    let configured_filter: Option<String> = get_optional(conf, &full_key(vec![prefix, "maps", endpoint]));
    let program_source = configured_filter
        .map(|filter| if multi { format!("map({})", filter) } else { filter })
        .unwrap_or_else(|| ".".to_string());

    // The program is compiled before the request, so an invalid filter fails without any network I/O.
    let mut program = jq_rs::compile(&program_source)?;

    let raw_body = get_data(conf, prefix, endpoint, param, token)?;
    let transformed = program.run(&raw_body)?;
    let value: T = serde_json::from_str(&transformed)?;
    Ok(value)
}