# stable-diffusion-implementation/main/myenv/lib/python3.10/site-packages/huggingface_hub/_oauth.py
import datetime
import hashlib
import logging
import os
import time
import urllib.parse
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple, Union

from . import constants
from .hf_api import whoami
from .utils import experimental, get_token


# Module-level logger used by the OAuth setup and session-parsing helpers below.
logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    # Only needed for the "fastapi.FastAPI" / "fastapi.Request" string annotations. fastapi is an optional
    # dependency (the functions below import it lazily and raise ImportError with install hints).
    import fastapi
@dataclass
class OAuthOrgInfo:
    """
    Information about an organization linked to a user logged in with OAuth.

    Attributes:
        sub (`str`):
            Unique identifier for the org. OpenID Connect field.
        name (`str`):
            The org's full name. OpenID Connect field.
        preferred_username (`str`):
            The org's username. OpenID Connect field.
        picture (`str`):
            The org's profile picture URL. OpenID Connect field.
        is_enterprise (`bool`):
            Whether the org is an enterprise org. Hugging Face field.
        can_pay (`Optional[bool]`, *optional*):
            Whether the org has a payment method set up. Hugging Face field.
        role_in_org (`Optional[str]`, *optional*):
            The user's role in the org. Hugging Face field.
        security_restrictions (`Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]]`, *optional*):
            List of security restrictions that the user hasn't completed for this org. Possible values: "ip",
            "token-policy", "mfa", "sso". Hugging Face field.
    """

    # `@dataclass` generates the keyword-argument `__init__` (plus `__repr__`/`__eq__`). `parse_huggingface_oauth`
    # builds instances with `OAuthOrgInfo(sub=..., ...)`, which raises TypeError on a plain class.
    sub: str
    name: str
    preferred_username: str
    picture: str
    is_enterprise: bool
    can_pay: Optional[bool] = None
    role_in_org: Optional[str] = None
    security_restrictions: Optional[List[Literal["ip", "token-policy", "mfa", "sso"]]] = None
@dataclass
class OAuthUserInfo:
    """
    Information about a user logged in with OAuth.

    Attributes:
        sub (`str`):
            Unique identifier for the user, even in case of rename. OpenID Connect field.
        name (`str`):
            The user's full name. OpenID Connect field.
        preferred_username (`str`):
            The user's username. OpenID Connect field.
        email_verified (`Optional[bool]`, *optional*):
            Indicates if the user's email is verified. OpenID Connect field.
        email (`Optional[str]`, *optional*):
            The user's email address. OpenID Connect field.
        picture (`str`):
            The user's profile picture URL. OpenID Connect field.
        profile (`str`):
            The user's profile URL. OpenID Connect field.
        website (`Optional[str]`, *optional*):
            The user's website URL. OpenID Connect field.
        is_pro (`bool`):
            Whether the user is a pro user. Hugging Face field.
        can_pay (`Optional[bool]`, *optional*):
            Whether the user has a payment method set up. Hugging Face field.
        orgs (`Optional[List[OAuthOrgInfo]]`, *optional*):
            List of organizations the user is part of. Hugging Face field.
    """

    # `@dataclass` generates the keyword-argument `__init__` used by `parse_huggingface_oauth`. No field has a
    # default, so every field (including the ones documented as optional) must be passed explicitly.
    sub: str
    name: str
    preferred_username: str
    email_verified: Optional[bool]
    email: Optional[str]
    picture: str
    profile: str
    website: Optional[str]
    is_pro: bool
    can_pay: Optional[bool]
    orgs: Optional[List[OAuthOrgInfo]]
@dataclass
class OAuthInfo:
    """
    Information about the OAuth login.

    Attributes:
        access_token (`str`):
            The access token.
        access_token_expires_at (`datetime.datetime`):
            The expiration date of the access token.
        user_info ([`OAuthUserInfo`]):
            The user information.
        state (`str`, *optional*):
            State sent to the OAuth provider in the original authorization request.
        scope (`str`):
            Granted scope.
    """

    # `@dataclass` generates the keyword-argument `__init__` that `parse_huggingface_oauth` relies on.
    access_token: str
    access_token_expires_at: datetime.datetime
    user_info: OAuthUserInfo
    state: Optional[str]
    scope: str
def attach_huggingface_oauth(app: "fastapi.FastAPI", route_prefix: str = "/"):
    """
    Add OAuth endpoints to a FastAPI app to enable OAuth login with Hugging Face.

    How to use:
    - Call this method on your FastAPI app to add the OAuth endpoints.
    - Inside your route handlers, call `parse_huggingface_oauth(request)` to retrieve the OAuth info.
    - If user is logged in, an [`OAuthInfo`] object is returned with the user's info. If not, `None` is returned.
    - In your app, make sure to add links to `/oauth/huggingface/login` and `/oauth/huggingface/logout` (prefixed
      with `route_prefix` if you set one) for the user to log in and out.

    If the `SPACE_ID` environment variable is set (i.e. the app runs in a Space), real OAuth routes are added.
    Otherwise, mocked routes are added that log the user in with the profile of the locally saved Hugging Face token.

    Args:
        app (`fastapi.FastAPI`):
            The app to which the session middleware and the OAuth routes are added.
        route_prefix (`str`, *optional*, defaults to `"/"`):
            Prefix under which the `oauth/huggingface/...` routes are mounted. Leading and trailing slashes are
            ignored.

    Raises:
        `ImportError`: if the optional OAuth dependencies are not installed.
        `ValueError`: when running in a Space whose OAuth settings (client id/secret, scopes, provider URL) are
            missing, or when running locally without a usable personal token.

    Example:
    ```py
    from fastapi import FastAPI, Request
    from huggingface_hub import attach_huggingface_oauth, parse_huggingface_oauth

    # Create a FastAPI app
    app = FastAPI()

    # Add OAuth endpoints to the FastAPI app
    attach_huggingface_oauth(app)

    # Add a route that greets the user if they are logged in
    @app.get("/")
    def greet_json(request: Request):
        # Retrieve the OAuth info from the request
        oauth_info = parse_huggingface_oauth(request)  # e.g. OAuthInfo dataclass
        if oauth_info is None:
            return {"msg": "Not logged in!"}
        return {"msg": f"Hello, {oauth_info.user_info.preferred_username}!"}
    ```
    """
    # NOTE(review): `experimental` is imported at module level but never applied in this file; confirm whether this
    # function is meant to be decorated with `@experimental`.

    # TODO: handle generic case (handling OAuth in a non-Space environment with custom dev values) (low priority)

    # Add SessionMiddleware to the FastAPI app to store the OAuth info in the session.
    # Session Middleware requires a secret key to sign the cookies. Let's use a hash
    # of the OAuth secret key to make it unique to the Space + updated in case OAuth
    # config gets updated. When run locally (no OAuth client secret), the key is the hash
    # of the constant "-v1" alone, i.e. predictable: acceptable for local debugging only.
    try:
        from starlette.middleware.sessions import SessionMiddleware
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install huggingface_hub[oauth]` or add "
            "`huggingface_hub[oauth]` to your requirements.txt file in order to install the required dependencies."
        ) from e

    session_secret = (constants.OAUTH_CLIENT_SECRET or "") + "-v1"
    app.add_middleware(
        SessionMiddleware,  # type: ignore[arg-type]
        secret_key=hashlib.sha256(session_secret.encode()).hexdigest(),
        # Cross-site cookie, presumably so the session survives when the app is embedded in an iframe.
        same_site="none",
        https_only=True,
    )  # type: ignore

    # Add OAuth endpoints to the FastAPI app:
    # - {route_prefix}/oauth/huggingface/login
    # - {route_prefix}/oauth/huggingface/callback
    # - {route_prefix}/oauth/huggingface/logout
    # If the app is running in a Space, OAuth is enabled normally.
    # Otherwise, we mock the endpoints to make the user log in with a fake user profile - without any calls to hf.co.
    route_prefix = route_prefix.strip("/")
    if os.getenv("SPACE_ID") is not None:
        logger.info("OAuth is enabled in the Space. Adding OAuth routes.")
        _add_oauth_routes(app, route_prefix=route_prefix)
    else:
        logger.info("App is not running in a Space. Adding mocked OAuth routes.")
        _add_mocked_oauth_routes(app, route_prefix=route_prefix)
def parse_huggingface_oauth(request: "fastapi.Request") -> Optional[OAuthInfo]:
    """
    Returns the information from a logged in user as a [`OAuthInfo`] object.

    For flexibility and future-proofing, this method is very lax in its parsing and does not raise errors.
    Missing fields are set to `None` without a warning (this includes the token expiration date).

    Return `None`, if the user is not logged in (no info in session cookie).

    See [`attach_huggingface_oauth`] for an example on how to use this method.

    Args:
        request (`fastapi.Request`):
            The incoming request. Its session (set up by the middleware added in [`attach_huggingface_oauth`]) is
            searched for an `"oauth_info"` entry.

    Returns:
        [`OAuthInfo`] or `None`: the parsed login info, or `None` if the session holds no `"oauth_info"` entry.
    """
    if "oauth_info" not in request.session:
        logger.debug("No OAuth info in session.")
        return None

    logger.debug("Parsing OAuth info from session.")
    oauth_data = request.session["oauth_info"]
    # `or {}` / `or []` also cover keys that are present but null, keeping the parsing lax as documented.
    user_data = oauth_data.get("userinfo") or {}
    orgs_data = user_data.get("orgs") or []

    orgs = (
        [
            OAuthOrgInfo(
                sub=org.get("sub"),
                name=org.get("name"),
                preferred_username=org.get("preferred_username"),
                picture=org.get("picture"),
                is_enterprise=org.get("isEnterprise"),
                can_pay=org.get("canPay"),
                role_in_org=org.get("roleInOrg"),
                security_restrictions=org.get("securityRestrictions"),
            )
            for org in orgs_data
        ]
        if orgs_data
        else None
    )

    user_info = OAuthUserInfo(
        sub=user_data.get("sub"),
        name=user_data.get("name"),
        preferred_username=user_data.get("preferred_username"),
        email_verified=user_data.get("email_verified"),
        email=user_data.get("email"),
        picture=user_data.get("picture"),
        profile=user_data.get("profile"),
        website=user_data.get("website"),
        is_pro=user_data.get("isPro"),
        can_pay=user_data.get("canPay"),
        orgs=orgs,
    )

    # `datetime.fromtimestamp(None)` raises TypeError, so a missing expiration must map to None explicitly to honor
    # the "does not raise errors" contract above.
    expires_at = oauth_data.get("expires_at")
    access_token_expires_at = datetime.datetime.fromtimestamp(expires_at) if expires_at is not None else None

    return OAuthInfo(
        access_token=oauth_data.get("access_token"),
        access_token_expires_at=access_token_expires_at,  # type: ignore[arg-type]
        user_info=user_info,
        state=oauth_data.get("state"),
        scope=oauth_data.get("scope"),
    )
def _add_oauth_routes(app: "fastapi.FastAPI", route_prefix: str) -> None:
    """Add OAuth routes to the FastAPI app (login, callback handler and logout).

    Args:
        app (`fastapi.FastAPI`):
            The app on which the three `GET` routes are registered.
        route_prefix (`str`):
            Prefix passed to `_get_oauth_uris` to build the route paths.

    Raises:
        `ImportError`: if `fastapi` or `authlib` is not installed.
        `ValueError`: if `OAUTH_CLIENT_ID`, `OAUTH_CLIENT_SECRET`, `OAUTH_SCOPES` or `OPENID_PROVIDER_URL` is not set.
    """
    try:
        import fastapi
        from authlib.integrations.base_client.errors import MismatchingStateError
        from authlib.integrations.starlette_client import OAuth
        from fastapi.responses import RedirectResponse
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install huggingface_hub[oauth]` or add "
            "`huggingface_hub[oauth]` to your requirements.txt file."
        ) from e

    # Check environment variables
    msg = (
        "OAuth is required but '{}' environment variable is not set. Make sure you've enabled OAuth in your Space by"
        " setting `hf_oauth: true` in the Space metadata."
    )
    if constants.OAUTH_CLIENT_ID is None:
        raise ValueError(msg.format("OAUTH_CLIENT_ID"))
    if constants.OAUTH_CLIENT_SECRET is None:
        raise ValueError(msg.format("OAUTH_CLIENT_SECRET"))
    if constants.OAUTH_SCOPES is None:
        raise ValueError(msg.format("OAUTH_SCOPES"))
    if constants.OPENID_PROVIDER_URL is None:
        raise ValueError(msg.format("OPENID_PROVIDER_URL"))

    # Register OAuth server
    oauth = OAuth()
    oauth.register(
        name="huggingface",
        client_id=constants.OAUTH_CLIENT_ID,
        client_secret=constants.OAUTH_CLIENT_SECRET,
        client_kwargs={"scope": constants.OAUTH_SCOPES},
        server_metadata_url=constants.OPENID_PROVIDER_URL + "/.well-known/openid-configuration",
    )

    login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix)

    # Register OAuth endpoints on `app`. Defining the handlers is not enough: without `app.get(...)` the app exposes no
    # OAuth route at all. The callback must keep its function name: FastAPI uses it as the route name, and
    # `_generate_redirect_uri` resolves the callback with `request.url_for("oauth_redirect_callback")`.
    @app.get(login_uri)
    async def oauth_login(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that redirects to HF OAuth page."""
        redirect_uri = _generate_redirect_uri(request)
        return await oauth.huggingface.authorize_redirect(request, redirect_uri)  # type: ignore

    @app.get(callback_uri)
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        try:
            oauth_info = await oauth.huggingface.authorize_access_token(request)  # type: ignore
        except MismatchingStateError:
            # Parse query params
            nb_redirects = int(request.query_params.get("_nb_redirects", 0))
            target_url = request.query_params.get("_target_url")

            # Build redirect URI with the same query params as before and bump nb_redirects count
            query_params: Dict[str, Union[int, str]] = {"_nb_redirects": nb_redirects + 1}
            if target_url:
                query_params["_target_url"] = target_url

            redirect_uri = f"{login_uri}?{urllib.parse.urlencode(query_params)}"

            # If the user is redirected more than `OAUTH_MAX_REDIRECTS` times, it is very likely that the cookie is
            # not working properly (e.g. browser is blocking third-party cookies in iframe). In this case, redirect
            # the user in the non-iframe view.
            if nb_redirects > constants.OAUTH_MAX_REDIRECTS:
                host = os.environ.get("SPACE_HOST")
                if host is None:  # cannot happen in a Space
                    raise RuntimeError(
                        "App is not running in a Space (SPACE_HOST environment variable is not set). Cannot redirect to non-iframe view."
                    ) from None
                host_url = "https://" + host.rstrip("/")
                return RedirectResponse(host_url + redirect_uri)

            # Redirect the user to the login page again
            return RedirectResponse(redirect_uri)

        # OAuth login worked => store the user info in the session and redirect
        logger.debug("Successfully logged in with OAuth. Storing user info in session.")
        request.session["oauth_info"] = oauth_info
        return RedirectResponse(_get_redirect_target(request))

    @app.get(logout_uri)
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete info from cookie session)."""
        logger.debug("Logged out with OAuth. Removing user info from session.")
        request.session.pop("oauth_info", None)
        return RedirectResponse(_get_redirect_target(request))
def _add_mocked_oauth_routes(app: "fastapi.FastAPI", route_prefix: str = "/") -> None:
    """Add fake oauth routes if app is run locally and OAuth is enabled.

    Using OAuth will have the same behavior as in a Space but instead of authenticating with HF, a mocked user profile
    is added to the session.

    Args:
        app (`fastapi.FastAPI`):
            The app on which the three mocked `GET` routes are registered.
        route_prefix (`str`, *optional*, defaults to `"/"`):
            Prefix passed to `_get_oauth_uris` to build the route paths.

    Raises:
        `ImportError`: if `fastapi`/`starlette` is not installed.
        `ValueError`: propagated from `_get_mocked_oauth_info` when no usable personal token is available locally.
    """
    try:
        import fastapi
        from fastapi.responses import RedirectResponse
        from starlette.datastructures import URL
    except ImportError as e:
        raise ImportError(
            "Cannot initialize OAuth due to a missing library. Please run `pip install huggingface_hub[oauth]` or add "
            "`huggingface_hub[oauth]` to your requirements.txt file."
        ) from e

    warnings.warn(
        "OAuth is not supported outside of a Space environment. To help you debug your app locally, the oauth endpoints"
        " are mocked to return your profile and token. To make it work, your machine must be logged in to Huggingface."
    )
    # Built once at setup time (token lookup + `whoami` call); each login reuses it with a fresh expiration date.
    mocked_oauth_info = _get_mocked_oauth_info()

    login_uri, callback_uri, logout_uri = _get_oauth_uris(route_prefix)

    # Define OAuth routes and register them on `app`. Without `app.get(...)` the handlers would never be reachable.
    # The callback must keep its function name: FastAPI uses it as the route name, and `_generate_redirect_uri`
    # resolves the callback with `request.url_for("oauth_redirect_callback")`.
    @app.get(login_uri)
    async def oauth_login(request: fastapi.Request) -> RedirectResponse:
        """Fake endpoint that redirects to HF OAuth page."""
        # Define target (where to redirect after login); there is no HF consent page, so go straight to the callback.
        redirect_uri = _generate_redirect_uri(request)
        return RedirectResponse(callback_uri + "?" + urllib.parse.urlencode({"_target_url": redirect_uri}))

    @app.get(callback_uri)
    async def oauth_redirect_callback(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that handles the OAuth callback."""
        # Recompute `expires_at` per login: the payload was built at setup time, so on a long-running dev server its
        # original expiration would already be in the past.
        request.session["oauth_info"] = {
            **mocked_oauth_info,
            "expires_at": int(time.time()) + mocked_oauth_info["expires_in"],
        }
        return RedirectResponse(_get_redirect_target(request))

    @app.get(logout_uri)
    async def oauth_logout(request: fastapi.Request) -> RedirectResponse:
        """Endpoint that logs out the user (e.g. delete cookie session)."""
        request.session.pop("oauth_info", None)
        logout_url = URL("/").include_query_params(**request.query_params)
        return RedirectResponse(url=logout_url, status_code=302)  # see https://github.com/gradio-app/gradio/pull/9659
def _generate_redirect_uri(request: "fastapi.Request") -> str:
    """Build the absolute URL of the OAuth callback route, carrying a `_target_url` query parameter.

    The target is the request's own `_target_url` query parameter when present. Otherwise it is `/` followed by the
    request's current query string, so those parameters survive the login round-trip.
    """
    params = request.query_params
    target = params["_target_url"] if "_target_url" in params else "/?" + urllib.parse.urlencode(params)

    callback_url = request.url_for("oauth_redirect_callback").include_query_params(_target_url=target)
    if not callback_url.netloc.endswith(".hf.space"):
        return str(callback_url)
    # Inside a Space the app sees the request as plain http, while the public URL is https.
    return str(callback_url).replace("http://", "https://")
def _get_redirect_target(request: "fastapi.Request", default_target: str = "/") -> str: | |
return request.query_params.get("_target_url", default_target) | |
def _get_mocked_oauth_info() -> Dict:
    """Build a fake OAuth session payload for local debugging.

    The payload is shaped like the session data read by `parse_huggingface_oauth` (token fields plus a `userinfo`
    section). It is filled from the locally available token and the profile returned by `whoami()`. Fields that cannot
    be derived locally (`sub`, `aud`, `nonce`, ...) are fixed placeholder values.

    Returns:
        `Dict`: the mocked payload, whose `expires_at` is set 8 hours after the call.

    Raises:
        `ValueError`: if no token is available locally, or if it does not belong to a personal (`"user"`) account.
    """
    token = get_token()
    if token is None:
        raise ValueError(
            "Your machine must be logged in to HF to debug an OAuth app locally. Please"
            " run `hf auth login` or set `HF_TOKEN` as environment variable "
            "with one of your access token. You can generate a new token in your "
            "settings page (https://huggingface.co/settings/tokens)."
        )

    user = whoami()
    if user["type"] != "user":
        raise ValueError(
            "Your machine is not logged in with a personal account. Please use a "
            "personal access token. You can generate a new token in your settings page"
            " (https://huggingface.co/settings/tokens)."
        )

    # Profile fields, looked up in the same order as they appear in `userinfo`.
    full_name = user["fullname"]
    username = user["name"]
    avatar_url = user["avatarUrl"]

    lifetime_seconds = 8 * 60 * 60  # 8 hours

    userinfo = {
        "sub": "0123456789",
        "name": full_name,
        "preferred_username": username,
        "profile": f"https://huggingface.co/{username}",
        "picture": avatar_url,
        "website": "",
        # Fixed dummy values for the remaining identity fields.
        "aud": "00000000-0000-0000-0000-000000000000",
        "auth_time": 1691672844,
        "nonce": "aaaaaaaaaaaaaaaaaaa",
        "iat": 1691672844,
        "exp": 1691676444,
        "iss": "https://huggingface.co",
    }

    return {
        "access_token": token,
        "token_type": "bearer",
        "expires_in": lifetime_seconds,
        "id_token": "FOOBAR",
        "scope": "openid profile",
        "refresh_token": "hf_oauth__refresh_token",
        "expires_at": int(time.time()) + lifetime_seconds,
        "userinfo": userinfo,
    }
def _get_oauth_uris(route_prefix: str = "/") -> Tuple[str, str, str]: | |
route_prefix = route_prefix.strip("/") | |
if route_prefix: | |
route_prefix = f"/{route_prefix}" | |
return ( | |
f"{route_prefix}/oauth/huggingface/login", | |
f"{route_prefix}/oauth/huggingface/callback", | |
f"{route_prefix}/oauth/huggingface/logout", | |
) | |