3
3
# Licensed under the MIT License.
4
4
# ------------------------------------
5
5
import os
6
+ import sys
6
7
import time
7
8
8
9
try :
@@ -883,9 +884,10 @@ def test_azure_arc(tmpdir):
883
884
"os.environ" ,
884
885
{EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
885
886
):
886
- token = ManagedIdentityCredential (transport = transport ).get_token (scope )
887
- assert token .token == access_token
888
- assert token .expires_on == expires_on
887
+ with mock .patch ("azure.identity._credentials.azure_arc._validate_key_file" , lambda x : None ):
888
+ token = ManagedIdentityCredential (transport = transport ).get_token (scope )
889
+ assert token .token == access_token
890
+ assert token .expires_on == expires_on
889
891
890
892
891
893
def test_azure_arc_tenant_id (tmpdir ):
@@ -936,9 +938,10 @@ def test_azure_arc_tenant_id(tmpdir):
936
938
"os.environ" ,
937
939
{EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
938
940
):
939
- token = ManagedIdentityCredential (transport = transport ).get_token (scope , tenant_id = "tenant_id" )
940
- assert token .token == access_token
941
- assert token .expires_on == expires_on
941
+ with mock .patch ("azure.identity._credentials.azure_arc._validate_key_file" , lambda x : None ):
942
+ token = ManagedIdentityCredential (transport = transport ).get_token (scope , tenant_id = "tenant_id" )
943
+ assert token .token == access_token
944
+ assert token .expires_on == expires_on
942
945
943
946
944
947
def test_azure_arc_client_id ():
@@ -950,10 +953,123 @@ def test_azure_arc_client_id():
950
953
EnvironmentVariables .IMDS_ENDPOINT : "http://localhost:42" ,
951
954
},
952
955
):
953
- credential = ManagedIdentityCredential (client_id = "some-guid" )
956
+ with mock .patch ("azure.identity._credentials.azure_arc._validate_key_file" , lambda x : None ):
957
+ credential = ManagedIdentityCredential (client_id = "some-guid" )
954
958
955
- with pytest .raises (ClientAuthenticationError ):
959
+ with pytest .raises (ClientAuthenticationError ) as ex :
956
960
credential .get_token ("scope" )
961
+ assert "not supported" in str (ex .value )
962
+
963
+
964
+ def test_azure_arc_key_too_large (tmp_path ):
965
+
966
+ api_version = "2019-11-01"
967
+ identity_endpoint = "http://localhost:42/token"
968
+ imds_endpoint = "http://localhost:42"
969
+ scope = "scope"
970
+ secret_key = "X" * 4097
971
+
972
+ key_file = tmp_path / "key_file.key"
973
+ key_file .write_text (secret_key )
974
+ assert key_file .read_text () == secret_key
975
+
976
+ transport = validating_transport (
977
+ requests = [
978
+ Request (
979
+ base_url = identity_endpoint ,
980
+ method = "GET" ,
981
+ required_headers = {"Metadata" : "true" },
982
+ required_params = {"api-version" : api_version , "resource" : scope },
983
+ ),
984
+ ],
985
+ responses = [
986
+ mock_response (status_code = 401 , headers = {"WWW-Authenticate" : "Basic realm={}" .format (key_file )}),
987
+ ],
988
+ )
989
+
990
+ with mock .patch (
991
+ "os.environ" ,
992
+ {EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
993
+ ):
994
+ with mock .patch ("azure.identity._credentials.azure_arc._get_key_file_path" , lambda : str (tmp_path )):
995
+ with pytest .raises (ClientAuthenticationError ) as ex :
996
+ ManagedIdentityCredential (transport = transport ).get_token (scope )
997
+ assert "file size" in str (ex .value )
998
+
999
+
1000
+ def test_azure_arc_key_not_exist (tmp_path ):
1001
+
1002
+ api_version = "2019-11-01"
1003
+ identity_endpoint = "http://localhost:42/token"
1004
+ imds_endpoint = "http://localhost:42"
1005
+ scope = "scope"
1006
+
1007
+ transport = validating_transport (
1008
+ requests = [
1009
+ Request (
1010
+ base_url = identity_endpoint ,
1011
+ method = "GET" ,
1012
+ required_headers = {"Metadata" : "true" },
1013
+ required_params = {"api-version" : api_version , "resource" : scope },
1014
+ ),
1015
+ ],
1016
+ responses = [
1017
+ mock_response (status_code = 401 , headers = {"WWW-Authenticate" : "Basic realm=/path/to/key_file" }),
1018
+ ],
1019
+ )
1020
+
1021
+ with mock .patch (
1022
+ "os.environ" ,
1023
+ {EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
1024
+ ):
1025
+ with pytest .raises (ClientAuthenticationError ) as ex :
1026
+ ManagedIdentityCredential (transport = transport ).get_token (scope )
1027
+ assert "not exist" in str (ex .value )
1028
+
1029
+
1030
+ def test_azure_arc_key_invalid (tmp_path ):
1031
+
1032
+ api_version = "2019-11-01"
1033
+ identity_endpoint = "http://localhost:42/token"
1034
+ imds_endpoint = "http://localhost:42"
1035
+ scope = "scope"
1036
+ key_file = tmp_path / "key_file.txt"
1037
+ key_file .write_text ("secret" )
1038
+
1039
+ transport = validating_transport (
1040
+ requests = [
1041
+ Request (
1042
+ base_url = identity_endpoint ,
1043
+ method = "GET" ,
1044
+ required_headers = {"Metadata" : "true" },
1045
+ required_params = {"api-version" : api_version , "resource" : scope },
1046
+ ),
1047
+ Request (
1048
+ base_url = identity_endpoint ,
1049
+ method = "GET" ,
1050
+ required_headers = {"Metadata" : "true" },
1051
+ required_params = {"api-version" : api_version , "resource" : scope },
1052
+ ),
1053
+ ],
1054
+ responses = [
1055
+ mock_response (status_code = 401 , headers = {"WWW-Authenticate" : "Basic realm={}" .format (key_file )}),
1056
+ mock_response (status_code = 401 , headers = {"WWW-Authenticate" : "Basic realm={}" .format (key_file )}),
1057
+ ],
1058
+ )
1059
+
1060
+ with mock .patch (
1061
+ "os.environ" ,
1062
+ {EnvironmentVariables .IDENTITY_ENDPOINT : identity_endpoint , EnvironmentVariables .IMDS_ENDPOINT : imds_endpoint },
1063
+ ):
1064
+ with mock .patch ("azure.identity._credentials.azure_arc._get_key_file_path" , lambda : "/foo" ):
1065
+ with pytest .raises (ClientAuthenticationError ) as ex :
1066
+ ManagedIdentityCredential (transport = transport ).get_token (scope )
1067
+ assert "Unexpected file path" in str (ex .value )
1068
+
1069
+ with mock .patch ("azure.identity._credentials.azure_arc._get_key_file_path" , lambda : str (tmp_path )):
1070
+ with pytest .raises (ClientAuthenticationError ) as ex :
1071
+ ManagedIdentityCredential (transport = transport ).get_token (scope )
1072
+ assert "extension" in str (ex .value )
957
1073
958
1074
959
1075
def test_token_exchange (tmpdir ):
0 commit comments