Skip to content

Commit cb065ac

Browse files
authored
[Identity] Managed identity bug fix (#36010)
Signed-off-by: Paul Van Eck <paulvaneck@microsoft.com>
1 parent 40cf085 commit cb065ac

File tree

6 files changed

+309
-18
lines changed

6 files changed

+309
-18
lines changed

sdk/identity/azure-identity/CHANGELOG.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
11
# Release History
22

3-
## 1.17.0 (2024-06-11)
3+
## 1.17.0b2 (2024-06-11)
44

55
### Features Added
66

77
- `OnBehalfOfCredential` now supports client assertion callbacks through the `client_assertion_func` keyword argument. This enables authenticating with client assertions such as federated credentials. ([#35812](https://github.com/Azure/azure-sdk-for-python/pull/35812))
88

9+
### Bugs Fixed
10+
11+
- Managed identity bug fixes
12+
13+
## 1.16.1 (2024-06-11)
14+
15+
### Bugs Fixed
16+
17+
- Managed identity bug fixes
18+
919
## 1.17.0b1 (2024-05-13)
1020

1121
### Features Added

sdk/identity/azure-identity/azure/identity/_credentials/azure_arc.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# ------------------------------------
55
import functools
66
import os
7+
import sys
78
from typing import Any, Dict, Optional
89

910
from azure.core.exceptions import ClientAuthenticationError
@@ -24,7 +25,7 @@ def get_client(self, **kwargs: Any) -> Optional[ManagedIdentityClient]:
2425
return ManagedIdentityClient(
2526
_per_retry_policies=[ArcChallengeAuthPolicy()],
2627
request_factory=functools.partial(_get_request, url),
27-
**kwargs
28+
**kwargs,
2829
)
2930
return None
3031

@@ -70,6 +71,12 @@ def _get_secret_key(response: PipelineResponse) -> str:
7071
raise ClientAuthenticationError(
7172
message="Did not receive a correct value from WWW-Authenticate header: {}".format(header)
7273
) from ex
74+
75+
try:
76+
_validate_key_file(key_file)
77+
except ValueError as ex:
78+
raise ClientAuthenticationError(message="The key file path is invalid: {}".format(ex)) from ex
79+
7380
with open(key_file, "r", encoding="utf-8") as file:
7481
try:
7582
return file.read()
@@ -80,6 +87,53 @@ def _get_secret_key(response: PipelineResponse) -> str:
8087
) from error
8188

8289

90+
def _get_key_file_path() -> str:
91+
"""Returns the expected path for the Azure Arc MSI key file based on the current platform.
92+
93+
Only Linux and Windows are supported.
94+
95+
:return: The expected path.
96+
:rtype: str
97+
:raises ValueError: If the current platform is not supported.
98+
"""
99+
if sys.platform.startswith("linux"):
100+
return "/var/opt/azcmagent/tokens"
101+
if sys.platform.startswith("win"):
102+
program_data_path = os.environ.get("PROGRAMDATA")
103+
if not program_data_path:
104+
raise ValueError("PROGRAMDATA environment variable is not set or is empty.")
105+
return os.path.join(f"{program_data_path}", "AzureConnectedMachineAgent", "Tokens")
106+
raise ValueError(f"Azure Arc MSI is not supported on this platform {sys.platform}")
107+
108+
109+
def _validate_key_file(file_path: str) -> None:
110+
"""Validates that a given Azure Arc MSI file path is valid for use.
111+
112+
A valid file will:
113+
1. Be in the expected path for the current platform.
114+
2. Have a `.key` extension.
115+
3. Be at most 4096 bytes in size.
116+
117+
:param str file_path: The path to the key file.
118+
:raises ClientAuthenticationError: If the file path is invalid.
119+
"""
120+
if not file_path:
121+
raise ValueError("The file path must not be empty.")
122+
123+
if not os.path.exists(file_path):
124+
raise ValueError(f"The file path does not exist: {file_path}")
125+
126+
expected_directory = _get_key_file_path()
127+
if not os.path.dirname(file_path) == expected_directory:
128+
raise ValueError(f"Unexpected file path from HIMDS service: {file_path}")
129+
130+
if not file_path.endswith(".key"):
131+
raise ValueError("The file path must have a '.key' extension.")
132+
133+
if os.path.getsize(file_path) > 4096:
134+
raise ValueError("The file size must be less than or equal to 4096 bytes.")
135+
136+
83137
class ArcChallengeAuthPolicy(HTTPPolicy):
84138
"""Policy for handling Azure Arc's challenge authentication"""
85139

sdk/identity/azure-identity/azure/identity/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
# Copyright (c) Microsoft Corporation.
33
# Licensed under the MIT License.
44
# ------------------------------------
5-
VERSION = "1.17.0"
5+
VERSION = "1.17.0b2"

sdk/identity/azure-identity/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
url="https://github.com/Azure/azure-sdk-for-python/tree/main/sdk/identity/azure-identity",
3939
keywords="azure, azure sdk",
4040
classifiers=[
41-
"Development Status :: 5 - Production/Stable",
41+
"Development Status :: 4 - Beta",
4242
"Programming Language :: Python",
4343
"Programming Language :: Python :: 3 :: Only",
4444
"Programming Language :: Python :: 3",

sdk/identity/azure-identity/tests/test_managed_identity.py

Lines changed: 124 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Licensed under the MIT License.
44
# ------------------------------------
55
import os
6+
import sys
67
import time
78

89
try:
@@ -883,9 +884,10 @@ def test_azure_arc(tmpdir):
883884
"os.environ",
884885
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
885886
):
886-
token = ManagedIdentityCredential(transport=transport).get_token(scope)
887-
assert token.token == access_token
888-
assert token.expires_on == expires_on
887+
with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
888+
token = ManagedIdentityCredential(transport=transport).get_token(scope)
889+
assert token.token == access_token
890+
assert token.expires_on == expires_on
889891

890892

891893
def test_azure_arc_tenant_id(tmpdir):
@@ -936,9 +938,10 @@ def test_azure_arc_tenant_id(tmpdir):
936938
"os.environ",
937939
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
938940
):
939-
token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
940-
assert token.token == access_token
941-
assert token.expires_on == expires_on
941+
with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
942+
token = ManagedIdentityCredential(transport=transport).get_token(scope, tenant_id="tenant_id")
943+
assert token.token == access_token
944+
assert token.expires_on == expires_on
942945

943946

944947
def test_azure_arc_client_id():
@@ -950,10 +953,123 @@ def test_azure_arc_client_id():
950953
EnvironmentVariables.IMDS_ENDPOINT: "http://localhost:42",
951954
},
952955
):
953-
credential = ManagedIdentityCredential(client_id="some-guid")
956+
with mock.patch("azure.identity._credentials.azure_arc._validate_key_file", lambda x: None):
957+
credential = ManagedIdentityCredential(client_id="some-guid")
954958

955-
with pytest.raises(ClientAuthenticationError):
959+
with pytest.raises(ClientAuthenticationError) as ex:
956960
credential.get_token("scope")
961+
assert "not supported" in str(ex.value)
962+
963+
964+
def test_azure_arc_key_too_large(tmp_path):
965+
966+
api_version = "2019-11-01"
967+
identity_endpoint = "http://localhost:42/token"
968+
imds_endpoint = "http://localhost:42"
969+
scope = "scope"
970+
secret_key = "X" * 4097
971+
972+
key_file = tmp_path / "key_file.key"
973+
key_file.write_text(secret_key)
974+
assert key_file.read_text() == secret_key
975+
976+
transport = validating_transport(
977+
requests=[
978+
Request(
979+
base_url=identity_endpoint,
980+
method="GET",
981+
required_headers={"Metadata": "true"},
982+
required_params={"api-version": api_version, "resource": scope},
983+
),
984+
],
985+
responses=[
986+
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
987+
],
988+
)
989+
990+
with mock.patch(
991+
"os.environ",
992+
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
993+
):
994+
with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
995+
with pytest.raises(ClientAuthenticationError) as ex:
996+
ManagedIdentityCredential(transport=transport).get_token(scope)
997+
assert "file size" in str(ex.value)
998+
999+
1000+
def test_azure_arc_key_not_exist(tmp_path):
1001+
1002+
api_version = "2019-11-01"
1003+
identity_endpoint = "http://localhost:42/token"
1004+
imds_endpoint = "http://localhost:42"
1005+
scope = "scope"
1006+
1007+
transport = validating_transport(
1008+
requests=[
1009+
Request(
1010+
base_url=identity_endpoint,
1011+
method="GET",
1012+
required_headers={"Metadata": "true"},
1013+
required_params={"api-version": api_version, "resource": scope},
1014+
),
1015+
],
1016+
responses=[
1017+
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm=/path/to/key_file"}),
1018+
],
1019+
)
1020+
1021+
with mock.patch(
1022+
"os.environ",
1023+
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
1024+
):
1025+
with pytest.raises(ClientAuthenticationError) as ex:
1026+
ManagedIdentityCredential(transport=transport).get_token(scope)
1027+
assert "not exist" in str(ex.value)
1028+
1029+
1030+
def test_azure_arc_key_invalid(tmp_path):
1031+
1032+
api_version = "2019-11-01"
1033+
identity_endpoint = "http://localhost:42/token"
1034+
imds_endpoint = "http://localhost:42"
1035+
scope = "scope"
1036+
key_file = tmp_path / "key_file.txt"
1037+
key_file.write_text("secret")
1038+
1039+
transport = validating_transport(
1040+
requests=[
1041+
Request(
1042+
base_url=identity_endpoint,
1043+
method="GET",
1044+
required_headers={"Metadata": "true"},
1045+
required_params={"api-version": api_version, "resource": scope},
1046+
),
1047+
Request(
1048+
base_url=identity_endpoint,
1049+
method="GET",
1050+
required_headers={"Metadata": "true"},
1051+
required_params={"api-version": api_version, "resource": scope},
1052+
),
1053+
],
1054+
responses=[
1055+
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
1056+
mock_response(status_code=401, headers={"WWW-Authenticate": "Basic realm={}".format(key_file)}),
1057+
],
1058+
)
1059+
1060+
with mock.patch(
1061+
"os.environ",
1062+
{EnvironmentVariables.IDENTITY_ENDPOINT: identity_endpoint, EnvironmentVariables.IMDS_ENDPOINT: imds_endpoint},
1063+
):
1064+
with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: "/foo"):
1065+
with pytest.raises(ClientAuthenticationError) as ex:
1066+
ManagedIdentityCredential(transport=transport).get_token(scope)
1067+
assert "Unexpected file path" in str(ex.value)
1068+
1069+
with mock.patch("azure.identity._credentials.azure_arc._get_key_file_path", lambda: str(tmp_path)):
1070+
with pytest.raises(ClientAuthenticationError) as ex:
1071+
ManagedIdentityCredential(transport=transport).get_token(scope)
1072+
assert "extension" in str(ex.value)
9571073

9581074

9591075
def test_token_exchange(tmpdir):

0 commit comments

Comments
 (0)