Skip to content

Commit 9f9d3b1

Browse files
authored
✨ Feature: add lock file for codegen (#231)
1 parent 20d75c9 commit 9f9d3b1

File tree

6 files changed

+287
-32
lines changed

6 files changed

+287
-32
lines changed

codegen/__init__.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from pathlib import Path
22
import shutil
3-
import sys
43
from typing import Any
54

6-
import httpx
75
from jinja2 import Environment, PackageLoader
6+
import tomlkit
87

98
from .config import Config
109
from .log import logger as logger
@@ -21,10 +20,7 @@
2120
from .parser.schemas import UnionSchema
2221
from .source import get_source
2322

24-
if sys.version_info >= (3, 11):
25-
import tomllib
26-
else:
27-
import tomli as tomllib
23+
LOCK_FILE_NAME = "versions.lock"
2824

2925
env = Environment(
3026
loader=PackageLoader("codegen"),
@@ -46,7 +42,9 @@
4642

4743

4844
def load_config() -> Config:
49-
pyproject = tomllib.loads(Path("./pyproject.toml").read_text(encoding="utf-8"))
45+
pyproject = tomlkit.parse(
46+
Path("./pyproject.toml").read_text(encoding="utf-8")
47+
).unwrap()
5048
config_dict: dict[str, Any] = pyproject.get("tool", {}).get("codegen", {})
5149

5250
return Config.model_validate(config_dict)
@@ -252,6 +250,10 @@ def build_versions(dir: Path, versions: dict[str, str], latest_version: str):
252250
logger.info("Successfully generated versions!")
253251

254252

253+
def build_lock_file(file: Path, lock_data: tomlkit.TOMLDocument):
254+
file.write_text(lock_data.as_string())
255+
256+
255257
def build():
256258
config = load_config()
257259
logger.info(f"Loaded config: {config!r}")
@@ -277,7 +279,8 @@ def build():
277279

278280
for description in config.descriptions:
279281
logger.info(f"Start getting OpenAPI source for {description.identifier}...")
280-
source = get_source(httpx.URL(description.source))
282+
source = get_source(description.source)
283+
description._actual_source = str(source.uri.copy_with(fragment=None))
281284
logger.info(f"Getting schema from {source.uri} succeeded!")
282285

283286
logger.info(f"Start parsing OpenAPI spec for {description.identifier}...")
@@ -353,3 +356,4 @@ def build():
353356
versions[latest_version],
354357
latest_model_names,
355358
)
359+
build_lock_file(config.output_dir / LOCK_FILE_NAME, config.to_lock())

codegen/config.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from typing import Any
33

44
from pydantic import BaseModel, Field
5+
from tomlkit import aot, comment, document, inline_table, item, table
6+
from tomlkit.items import Table
7+
from tomlkit.toml_document import TOMLDocument
58

69

710
class Override(BaseModel):
@@ -13,6 +16,29 @@ class Override(BaseModel):
1316
class VersionedOverride(Override):
1417
target_descriptions: list[str] = Field(default_factory=list)
1518

19+
def to_lock(self) -> Table:
20+
tab = table()
21+
if self.target_descriptions:
22+
tab.append("target_descriptions", item(self.target_descriptions))
23+
if self.class_overrides:
24+
class_tab = table()
25+
for key, value in self.class_overrides.items():
26+
class_tab.append(key, value)
27+
tab.append("class_overrides", class_tab)
28+
if self.field_overrides:
29+
field_tab = table()
30+
for key, value in self.field_overrides.items():
31+
field_tab.append(key, value)
32+
tab.append("field_overrides", field_tab)
33+
if self.schema_overrides:
34+
schema_tab = table()
35+
for key, value in self.schema_overrides.items():
36+
value_table = inline_table()
37+
value_table.update(value)
38+
schema_tab.append(key, value_table)
39+
tab.append("schema_overrides", schema_tab)
40+
return tab
41+
1642

1743
class DescriptionConfig(BaseModel):
1844
version: str
@@ -24,11 +50,31 @@ class DescriptionConfig(BaseModel):
2450
is_latest: bool = False
2551
"""If true, the description will be used as the default description."""
2652
source: str
53+
"""Source link to the description file."""
2754

55+
_actual_source: str | None = None
56+
"""The actual source link after downloading, if applicable."""
2857

29-
class Config(BaseModel):
30-
output_dir: Path
31-
legacy_rest_models: Path
58+
@property
59+
def actual_source(self) -> str:
60+
"""Returns the actual source link after downloading, if applicable."""
61+
return self._actual_source or self.source
62+
63+
def to_lock(self) -> Table:
64+
tab = table()
65+
tab.update(
66+
{
67+
"version": self.version,
68+
"identifier": self.identifier,
69+
"module": self.module,
70+
"is_latest": self.is_latest,
71+
"source": self.actual_source,
72+
}
73+
)
74+
return tab
75+
76+
77+
class GenerationInfo(BaseModel):
3278
descriptions: list[DescriptionConfig]
3379
overrides: list[VersionedOverride] = Field(default_factory=list)
3480

@@ -56,3 +102,24 @@ def get_override_config_for_version(self, version_id: str) -> Override:
56102
for key, value in override.schema_overrides.items()
57103
},
58104
)
105+
106+
def to_lock(self) -> TOMLDocument:
107+
doc = document()
108+
doc.append(None, comment("DO NOT EDIT THIS FILE!"))
109+
doc.append(None, comment("This file is automatically @generated by githubkit."))
110+
111+
descriptions_aot = aot()
112+
for description in self.descriptions:
113+
descriptions_aot.append(description.to_lock())
114+
doc.append("descriptions", descriptions_aot)
115+
116+
overrides_aot = aot()
117+
for override in self.overrides:
118+
overrides_aot.append(override.to_lock())
119+
doc.append("overrides", overrides_aot)
120+
return doc
121+
122+
123+
class Config(GenerationInfo):
124+
output_dir: Path
125+
legacy_rest_models: Path

codegen/source.py

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
from dataclasses import dataclass
22
from functools import cache
3-
import json
4-
from pathlib import Path
3+
import os
54
from typing import Any
65

76
import httpx
87
from jsonpointer import JsonPointer
98

9+
REPO_COMMIT = os.getenv(
10+
"REPO_COMMIT",
11+
"https://api.github.com/repos/github/rest-api-description/commits/main",
12+
)
13+
RAW_SOURCE_PREFIX = os.getenv(
14+
"RAW_SOURCE_PREFIX",
15+
"https://raw.github.com/github/rest-api-description/",
16+
)
17+
1018

1119
@dataclass(frozen=True)
1220
class Source:
@@ -37,19 +45,32 @@ def __truediv__(self, other: str | int) -> "Source":
3745

3846

3947
@cache
40-
def get_content(source: httpx.URL | Path) -> dict:
41-
return (
42-
json.loads(source.read_text(encoding="utf-8"))
43-
if isinstance(source, Path)
44-
else httpx.get(
45-
source, headers={"User-Agent": "GitHubKit Codegen"}, follow_redirects=True
46-
).json()
48+
def get_content(source: str | httpx.URL) -> tuple[httpx.URL, dict]:
49+
if isinstance(source, str):
50+
sha_response = httpx.get(
51+
REPO_COMMIT,
52+
headers={
53+
"User-Agent": "GitHubKit Codegen",
54+
"Accept": "application/vnd.github.sha",
55+
},
56+
)
57+
sha_response.raise_for_status()
58+
sha = sha_response.text.strip()
59+
source_link = httpx.URL(RAW_SOURCE_PREFIX).join(f"{sha}/{source.lstrip('/')}")
60+
else:
61+
source_link = source
62+
63+
response = httpx.get(
64+
source_link, headers={"User-Agent": "GitHubKit Codegen"}, follow_redirects=True
4765
)
66+
response.raise_for_status()
67+
uri = response.url
68+
content = response.json()
69+
return uri, content
4870

4971

50-
def get_source(source: httpx.URL | Path, path: str | None = None) -> Source:
51-
if isinstance(source, Path):
52-
uri = httpx.URL(source.resolve().as_uri(), fragment=path)
53-
else:
54-
uri = source if path is None else source.copy_with(fragment=path)
55-
return Source(uri=uri, root=get_content(source))
72+
def get_source(source: str | httpx.URL, path: str | None = None) -> Source:
73+
uri, root = get_content(source)
74+
if path is not None:
75+
uri = uri.copy_with(fragment=path)
76+
return Source(uri=uri, root=root)

0 commit comments

Comments
 (0)