oidc.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. package providers
  2. import (
  3. "encoding/json"
  4. "errors"
  5. "fmt"
  6. "io/ioutil"
  7. "net/http"
  8. "os"
  9. "strconv"
  10. "strings"
  11. "time"
  12. httputil "github.com/aliyun/credentials-go/credentials/internal/http"
  13. "github.com/aliyun/credentials-go/credentials/internal/utils"
  14. )
  // OIDCCredentialsProvider obtains temporary STS credentials by calling the
// STS AssumeRoleWithOIDC action with an OIDC token read from a file, and
// caches them until they are within 180 seconds of expiry. Create it through
// NewOIDCCredentialsProviderBuilder and Build.
type OIDCCredentialsProvider struct {
	// oidcProviderARN is sent as the OIDCProviderArn form field. Build falls
	// back to ALIBABA_CLOUD_OIDC_PROVIDER_ARN when it is empty.
	oidcProviderARN string
	// oidcTokenFilePath names the file holding the OIDC token; the file is
	// re-read on every STS call. Build falls back to
	// ALIBABA_CLOUD_OIDC_TOKEN_FILE when it is empty.
	oidcTokenFilePath string
	// roleArn is sent as the RoleArn form field. Build falls back to
	// ALIBABA_CLOUD_ROLE_ARN when it is empty.
	roleArn string
	// roleSessionName is sent as RoleSessionName; Build defaults it to
	// "credentials-go-<unix microseconds>".
	roleSessionName string
	// durationSeconds is sent as DurationSeconds; Build turns 0 into 3600 and
	// reports values below 900 as an error.
	durationSeconds int
	// policy is sent as the Policy form field only when non-empty.
	policy string
	// for sts endpoint: stsRegionId and enableVpc are consulted by Build only
	// when stsEndpoint is empty.
	stsRegionId string
	enableVpc bool
	// stsEndpoint is the host the STS request is sent to (over https).
	stsEndpoint string
	// lastUpdateTimestamp is the Unix time (seconds) of the last successful
	// refresh in GetCredentials.
	lastUpdateTimestamp int64
	// expirationTimestamp is the Unix time (seconds) at which the cached
	// credentials expire; 0 means nothing has been fetched yet.
	expirationTimestamp int64
	// sessionCredentials caches the credentials from the last STS response.
	sessionCredentials *sessionCredentials
	// for http options: ConnectTimeout / ReadTimeout are interpreted as
	// milliseconds (used only when > 0) and Proxy is applied when non-empty.
	httpOptions *HttpOptions
}
  // OIDCCredentialsProviderBuilder accumulates settings for an
// OIDCCredentialsProvider. Each With* method mutates the wrapped provider and
// returns the same builder so calls can be chained; Build validates the
// settings and fills in defaults.
type OIDCCredentialsProviderBuilder struct {
	// provider is the instance under construction; Build returns this same
	// pointer on success.
	provider *OIDCCredentialsProvider
}
  35. func NewOIDCCredentialsProviderBuilder() *OIDCCredentialsProviderBuilder {
  36. return &OIDCCredentialsProviderBuilder{
  37. provider: &OIDCCredentialsProvider{},
  38. }
  39. }
  40. func (b *OIDCCredentialsProviderBuilder) WithOIDCProviderARN(oidcProviderArn string) *OIDCCredentialsProviderBuilder {
  41. b.provider.oidcProviderARN = oidcProviderArn
  42. return b
  43. }
  44. func (b *OIDCCredentialsProviderBuilder) WithOIDCTokenFilePath(oidcTokenFilePath string) *OIDCCredentialsProviderBuilder {
  45. b.provider.oidcTokenFilePath = oidcTokenFilePath
  46. return b
  47. }
  48. func (b *OIDCCredentialsProviderBuilder) WithRoleArn(roleArn string) *OIDCCredentialsProviderBuilder {
  49. b.provider.roleArn = roleArn
  50. return b
  51. }
  52. func (b *OIDCCredentialsProviderBuilder) WithRoleSessionName(roleSessionName string) *OIDCCredentialsProviderBuilder {
  53. b.provider.roleSessionName = roleSessionName
  54. return b
  55. }
  56. func (b *OIDCCredentialsProviderBuilder) WithDurationSeconds(durationSeconds int) *OIDCCredentialsProviderBuilder {
  57. b.provider.durationSeconds = durationSeconds
  58. return b
  59. }
  60. func (b *OIDCCredentialsProviderBuilder) WithStsRegionId(regionId string) *OIDCCredentialsProviderBuilder {
  61. b.provider.stsRegionId = regionId
  62. return b
  63. }
  64. func (b *OIDCCredentialsProviderBuilder) WithEnableVpc(enableVpc bool) *OIDCCredentialsProviderBuilder {
  65. b.provider.enableVpc = enableVpc
  66. return b
  67. }
  68. func (b *OIDCCredentialsProviderBuilder) WithPolicy(policy string) *OIDCCredentialsProviderBuilder {
  69. b.provider.policy = policy
  70. return b
  71. }
  72. func (b *OIDCCredentialsProviderBuilder) WithSTSEndpoint(stsEndpoint string) *OIDCCredentialsProviderBuilder {
  73. b.provider.stsEndpoint = stsEndpoint
  74. return b
  75. }
  76. func (b *OIDCCredentialsProviderBuilder) WithHttpOptions(httpOptions *HttpOptions) *OIDCCredentialsProviderBuilder {
  77. b.provider.httpOptions = httpOptions
  78. return b
  79. }
  80. func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvider, err error) {
  81. if b.provider.roleSessionName == "" {
  82. b.provider.roleSessionName = "credentials-go-" + strconv.FormatInt(time.Now().UnixNano()/1000, 10)
  83. }
  84. if b.provider.oidcTokenFilePath == "" {
  85. b.provider.oidcTokenFilePath = os.Getenv("ALIBABA_CLOUD_OIDC_TOKEN_FILE")
  86. }
  87. if b.provider.oidcTokenFilePath == "" {
  88. err = errors.New("the OIDCTokenFilePath is empty")
  89. return
  90. }
  91. if b.provider.oidcProviderARN == "" {
  92. b.provider.oidcProviderARN = os.Getenv("ALIBABA_CLOUD_OIDC_PROVIDER_ARN")
  93. }
  94. if b.provider.oidcProviderARN == "" {
  95. err = errors.New("the OIDCProviderARN is empty")
  96. return
  97. }
  98. if b.provider.roleArn == "" {
  99. b.provider.roleArn = os.Getenv("ALIBABA_CLOUD_ROLE_ARN")
  100. }
  101. if b.provider.roleArn == "" {
  102. err = errors.New("the RoleArn is empty")
  103. return
  104. }
  105. if b.provider.durationSeconds == 0 {
  106. b.provider.durationSeconds = 3600
  107. }
  108. if b.provider.durationSeconds < 900 {
  109. err = errors.New("the Assume Role session duration should be in the range of 15min - max duration seconds")
  110. }
  111. if b.provider.stsEndpoint == "" {
  112. if !b.provider.enableVpc {
  113. b.provider.enableVpc = strings.ToLower(os.Getenv("ALIBABA_CLOUD_VPC_ENDPOINT_ENABLED")) == "true"
  114. }
  115. prefix := "sts"
  116. if b.provider.enableVpc {
  117. prefix = "sts-vpc"
  118. }
  119. if b.provider.stsRegionId != "" {
  120. b.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, b.provider.stsRegionId)
  121. } else if region := os.Getenv("ALIBABA_CLOUD_STS_REGION"); region != "" {
  122. b.provider.stsEndpoint = fmt.Sprintf("%s.%s.aliyuncs.com", prefix, region)
  123. } else {
  124. b.provider.stsEndpoint = "sts.aliyuncs.com"
  125. }
  126. }
  127. provider = b.provider
  128. return
  129. }
  130. func (provider *OIDCCredentialsProvider) getCredentials() (session *sessionCredentials, err error) {
  131. req := &httputil.Request{
  132. Method: "POST",
  133. Protocol: "https",
  134. Host: provider.stsEndpoint,
  135. Headers: map[string]string{},
  136. }
  137. connectTimeout := 5 * time.Second
  138. readTimeout := 10 * time.Second
  139. if provider.httpOptions != nil && provider.httpOptions.ConnectTimeout > 0 {
  140. connectTimeout = time.Duration(provider.httpOptions.ConnectTimeout) * time.Millisecond
  141. }
  142. if provider.httpOptions != nil && provider.httpOptions.ReadTimeout > 0 {
  143. readTimeout = time.Duration(provider.httpOptions.ReadTimeout) * time.Millisecond
  144. }
  145. if provider.httpOptions != nil && provider.httpOptions.Proxy != "" {
  146. req.Proxy = provider.httpOptions.Proxy
  147. }
  148. req.ConnectTimeout = connectTimeout
  149. req.ReadTimeout = readTimeout
  150. queries := make(map[string]string)
  151. queries["Version"] = "2015-04-01"
  152. queries["Action"] = "AssumeRoleWithOIDC"
  153. queries["Format"] = "JSON"
  154. queries["Timestamp"] = utils.GetTimeInFormatISO8601()
  155. req.Queries = queries
  156. bodyForm := make(map[string]string)
  157. bodyForm["RoleArn"] = provider.roleArn
  158. bodyForm["OIDCProviderArn"] = provider.oidcProviderARN
  159. token, err := ioutil.ReadFile(provider.oidcTokenFilePath)
  160. if err != nil {
  161. return
  162. }
  163. bodyForm["OIDCToken"] = string(token)
  164. if provider.policy != "" {
  165. bodyForm["Policy"] = provider.policy
  166. }
  167. bodyForm["RoleSessionName"] = provider.roleSessionName
  168. bodyForm["DurationSeconds"] = strconv.Itoa(provider.durationSeconds)
  169. req.Form = bodyForm
  170. // set headers
  171. req.Headers["Accept-Encoding"] = "identity"
  172. res, err := httpDo(req)
  173. if err != nil {
  174. return
  175. }
  176. if res.StatusCode != http.StatusOK {
  177. message := "get session token failed: "
  178. err = errors.New(message + string(res.Body))
  179. return
  180. }
  181. var data assumeRoleResponse
  182. err = json.Unmarshal(res.Body, &data)
  183. if err != nil {
  184. err = fmt.Errorf("get oidc sts token err, json.Unmarshal fail: %s", err.Error())
  185. return
  186. }
  187. if data.Credentials == nil {
  188. err = fmt.Errorf("get oidc sts token err, fail to get credentials")
  189. return
  190. }
  191. if data.Credentials.AccessKeyId == nil || data.Credentials.AccessKeySecret == nil || data.Credentials.SecurityToken == nil {
  192. err = fmt.Errorf("refresh RoleArn sts token err, fail to get credentials")
  193. return
  194. }
  195. session = &sessionCredentials{
  196. AccessKeyId: *data.Credentials.AccessKeyId,
  197. AccessKeySecret: *data.Credentials.AccessKeySecret,
  198. SecurityToken: *data.Credentials.SecurityToken,
  199. Expiration: *data.Credentials.Expiration,
  200. }
  201. return
  202. }
  203. func (provider *OIDCCredentialsProvider) needUpdateCredential() (result bool) {
  204. if provider.expirationTimestamp == 0 {
  205. return true
  206. }
  207. return provider.expirationTimestamp-time.Now().Unix() <= 180
  208. }
  209. func (provider *OIDCCredentialsProvider) GetCredentials() (cc *Credentials, err error) {
  210. if provider.sessionCredentials == nil || provider.needUpdateCredential() {
  211. sessionCredentials, err1 := provider.getCredentials()
  212. if err1 != nil {
  213. return nil, err1
  214. }
  215. provider.sessionCredentials = sessionCredentials
  216. expirationTime, err2 := time.Parse("2006-01-02T15:04:05Z", sessionCredentials.Expiration)
  217. if err2 != nil {
  218. return nil, err2
  219. }
  220. provider.lastUpdateTimestamp = time.Now().Unix()
  221. provider.expirationTimestamp = expirationTime.Unix()
  222. }
  223. cc = &Credentials{
  224. AccessKeyId: provider.sessionCredentials.AccessKeyId,
  225. AccessKeySecret: provider.sessionCredentials.AccessKeySecret,
  226. SecurityToken: provider.sessionCredentials.SecurityToken,
  227. ProviderName: provider.GetProviderName(),
  228. }
  229. return
  230. }
  231. func (provider *OIDCCredentialsProvider) GetProviderName() string {
  232. return "oidc_role_arn"
  233. }