diff --git a/nss_pam_oidc/config.py b/nss_pam_oidc/config.py index 690c411c3b3902caf5842ac7e63f7ad658db8a49..62f12b198986111d608f2a1e622be3ada34e931c 100644 --- a/nss_pam_oidc/config.py +++ b/nss_pam_oidc/config.py @@ -18,13 +18,15 @@ import toml DEFAULT_CONFIG_FILE = "/etc/nss_pam_oidc.toml" DEFAULT_CONFIG = { - "pam": {}, + "pam": { + "flow": "password", + }, "nss": {}, } -def get_config(section: str, **kwargs): - """Get configuration for a section, layered from defaults, config file, and args.""" +def load_config(config_file: str) -> dict[str, dict[str, str]]: + """Load full configuration, amended wit hdefault config.""" config = {} # First, copy default configuration @@ -32,13 +34,25 @@ def get_config(section: str, **kwargs): config[name] = conf.copy() # Second, load configuration from file - config_file = kwargs.pop("config", DEFAULT_CONFIG_FILE) if os.path.exists(config_file): config_toml = toml.load(config_file) for name, conf in config.items(): conf.update(config_toml.get(name, {})) + return config + + +def get_config(section: str, **kwargs) -> dict[str, str]: + """Get configuration for a section, layered from defaults, config file, and args.""" + config_file = kwargs.pop("config", DEFAULT_CONFIG_FILE) + config = load_config(config_file) + # Lastly, override with passed arguments config[section].update(kwargs) - return config + return config[section] + + +def filter_config(config: dict[str, str], keys: list[str]) -> dict[str, str]: + """Return a copy of the config dictionary with only the requested keys.""" + return {k: v for k, v in config.items() if k in keys} diff --git a/nss_pam_oidc/pam.py b/nss_pam_oidc/pam.py index f0d51814b58f6aa2804d76dd7c3978c0cda537b6..4353e6eb406080081d6b4bc03538b8456836fa50 100644 --- a/nss_pam_oidc/pam.py +++ b/nss_pam_oidc/pam.py @@ -12,26 +12,79 @@ # See the License for the specific language governing permissions and # limitations under the License. +from oauthlib.oauth2 import InvalidGrantError, LegacyApplicationClient +from requests.exceptions import RequestException +from requests_oauthlib import OAuth2Session + +from .config import filter_config, get_config + + +def _split_argv(argv: list[str]) -> dict[str, str]: + args = {} + for arg in argv: + if "=" in arg: + name, val = arg.split("=", 1) + else: + name, val = arg, True + args[name] = val + return args + + +def _do_legacy_auth(username: str, password: str, config: dict[str, str]): + client_config = filter_config(config, ["client_id"]) + session_config = filter_config(config, ["client_id"]) + fetch_config = filter_config(config, ["client_id", "client_secret", "token_url"]) + + client = LegacyApplicationClient(**client_config) + session = OAuth2Session(client=client, **session_config) + + token = session.fetch_token(username=username, password=password, **fetch_config) + return token + + def pam_sm_authenticate(pamh, flags, argv): - try: - user = pamh.get_user(None) - except pamh.exception as e: - return e.pam_result - if user == None: - pamh.user = DEFAULT_USER - return pamh.PAM_SUCCESS + args = _split_argv(argv) + config = get_config("pam", args) + + if config.pop("flow") == "password": + try: + user = pamh.get_user(None) + password = pamh.authtok + except pamh.exception as e: + return e.pam_result + if user is None or password is None: + return pamh.PAM_CRED_INSUFFICIENT + + try: + token = _do_legacy_auth(username, password) + except InvalidGrantError: + return pamh.PAM_AUTH_ERROR + except RequestException: + return pamh.PAM_AUTHINFO_UNAVAIL + except: + return pamh.PAM_SERVICE_ERR + + if "access_token" in token: + return pamh.PAM_SUCCESS + + return pamh.PAM_AUTH_ERR + def pam_sm_setcred(pamh, flags, argv): - return pamh.PAM_SUCCESS + return pamh.PAM_SUCCESS + def pam_sm_acct_mgmt(pamh, flags, argv): - return pamh.PAM_SUCCESS + return pamh.PAM_SUCCESS + def pam_sm_open_session(pamh, flags, argv): - return pamh.PAM_SUCCESS + return pamh.PAM_SUCCESS + def pam_sm_close_session(pamh, flags, argv): - return pamh.PAM_SUCCESS + return pamh.PAM_SUCCESS + def pam_sm_chauthtok(pamh, flags, argv): - return pamh.PAM_SUCCESS + return pamh.PAM_SUCCESS diff --git a/pyproject.toml b/pyproject.toml index 80cab81169803b8faccdae651de29cf927fefcd7..b97730d6b35f11d46f0572bea95f3a53e9c2910e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,8 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" toml = "^0.10.2" +oauthlib = "^3.1.0" +requests-oauthlib = "^1.3.0" [tool.poetry.dev-dependencies]