Skip to content

Commit 7cf6a9a

Browse files
macro: allow flex imports
1 parent 204ddce commit 7cf6a9a

File tree

5 files changed

+302
-68
lines changed

5 files changed

+302
-68
lines changed

sdk-libs/macros/src/compressible.rs

Lines changed: 224 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use syn::{
66
parse::{Parse, ParseStream},
77
punctuated::Punctuated,
88
visit_mut, Attribute, Expr, Field, Ident, Item, ItemEnum, ItemFn, ItemMod, ItemStruct, Result,
9-
Token,
9+
Token, UseTree,
1010
};
1111

1212
/// Parse a comma-separated list of identifiers
@@ -29,6 +29,88 @@ struct SeedInfo {
2929
bump_field: Option<Ident>,
3030
}
3131

32+
/// Information about imported items from use statements
33+
#[derive(Debug, Clone)]
34+
struct ImportInfo {
35+
/// Map from local name to full path
36+
imports: std::collections::HashMap<String, String>,
37+
}
38+
39+
impl ImportInfo {
40+
fn new() -> Self {
41+
Self {
42+
imports: std::collections::HashMap::new(),
43+
}
44+
}
45+
46+
fn add_import(&mut self, local_name: String, full_path: String) {
47+
self.imports.insert(local_name, full_path);
48+
}
49+
50+
fn resolve_type(&self, type_name: &str) -> Option<&String> {
51+
self.imports.get(type_name)
52+
}
53+
}
54+
55+
/// Parse use statements to understand imported items
56+
fn parse_use_statements(module_items: &[Item]) -> ImportInfo {
57+
let mut import_info = ImportInfo::new();
58+
59+
for item in module_items {
60+
if let Item::Use(item_use) = item {
61+
extract_imports_from_use_tree(&item_use.tree, &mut import_info, String::new());
62+
}
63+
}
64+
65+
import_info
66+
}
67+
68+
/// Recursively extract imports from a use tree
69+
fn extract_imports_from_use_tree(
70+
use_tree: &UseTree,
71+
import_info: &mut ImportInfo,
72+
base_path: String,
73+
) {
74+
match use_tree {
75+
UseTree::Path(use_path) => {
76+
let new_base = if base_path.is_empty() {
77+
use_path.ident.to_string()
78+
} else {
79+
format!("{}::{}", base_path, use_path.ident)
80+
};
81+
extract_imports_from_use_tree(&use_path.tree, import_info, new_base);
82+
}
83+
UseTree::Name(use_name) => {
84+
let local_name = use_name.ident.to_string();
85+
let full_path = if base_path.is_empty() {
86+
local_name.clone()
87+
} else {
88+
format!("{}::{}", base_path, local_name)
89+
};
90+
import_info.add_import(local_name, full_path);
91+
}
92+
UseTree::Rename(use_rename) => {
93+
let local_name = use_rename.rename.to_string();
94+
let full_path = if base_path.is_empty() {
95+
use_rename.ident.to_string()
96+
} else {
97+
format!("{}::{}", base_path, use_rename.ident)
98+
};
99+
import_info.add_import(local_name, full_path);
100+
}
101+
UseTree::Glob(_) => {
102+
// For glob imports, we can't easily resolve specific items
103+
// In this case, we'll add a special marker for the base path
104+
import_info.add_import("*".to_string(), base_path);
105+
}
106+
UseTree::Group(use_group) => {
107+
for tree in &use_group.items {
108+
extract_imports_from_use_tree(tree, import_info, base_path.clone());
109+
}
110+
}
111+
}
112+
}
113+
32114
/// Extract instruction parameter names from #[instruction(...)] attribute
33115
fn extract_instruction_param_names(attrs: &[Attribute]) -> Vec<String> {
34116
for attr in attrs {
@@ -79,8 +161,40 @@ fn has_accounts_derive(attrs: &[Attribute]) -> bool {
79161
})
80162
}
81163

82-
/// Scan module items to find account structs that initialize the given account type
83-
fn find_account_seeds_for_type(
164+
/// Enhanced function to find seeds that can handle imports and re-exports
165+
fn find_account_seeds_for_type_enhanced(
166+
module_items: &[Item],
167+
account_type: &Ident,
168+
import_info: &ImportInfo,
169+
) -> Result<Option<SeedInfo>> {
170+
// First, try the original approach (look for directly defined structs)
171+
if let Some(seeds_info) = find_account_seeds_for_type_original(module_items, account_type)? {
172+
return Ok(Some(seeds_info));
173+
}
174+
175+
// Then, try to find imported or re-exported structs
176+
for item in module_items {
177+
if let Item::Struct(item_struct) = item {
178+
if has_accounts_derive(&item_struct.attrs) {
179+
if let syn::Fields::Named(fields) = &item_struct.fields {
180+
for field in &fields.named {
181+
// Try to match field types with account_type, considering imports
182+
if let Some(seeds_info) =
183+
extract_seeds_from_field_enhanced(field, account_type, import_info)?
184+
{
185+
return Ok(Some(seeds_info));
186+
}
187+
}
188+
}
189+
}
190+
}
191+
}
192+
193+
Ok(None)
194+
}
195+
196+
/// Original seed finding function (for direct definitions)
197+
fn find_account_seeds_for_type_original(
84198
module_items: &[Item],
85199
account_type: &Ident,
86200
) -> Result<Option<SeedInfo>> {
@@ -138,6 +252,65 @@ fn matches_account_type(ty: &syn::Type, target_type: &Ident) -> bool {
138252
}
139253
}
140254

255+
/// Enhanced type matching that considers imports
256+
fn matches_account_type_enhanced(
257+
ty: &syn::Type,
258+
target_type: &Ident,
259+
import_info: &ImportInfo,
260+
) -> bool {
261+
// First try the original approach
262+
if matches_account_type(ty, target_type) {
263+
return true;
264+
}
265+
266+
// Then try with import resolution
267+
match ty {
268+
syn::Type::Path(type_path) => {
269+
if let Some(last_segment) = type_path.path.segments.last() {
270+
let account_type_names = ["Account", "AccountLoader", "InterfaceAccount"];
271+
if account_type_names.contains(&&*last_segment.ident.to_string()) {
272+
if let syn::PathArguments::AngleBracketed(args) = &last_segment.arguments {
273+
for arg in &args.args {
274+
if let syn::GenericArgument::Type(syn::Type::Path(inner_type)) = arg {
275+
// Try to resolve the inner type through imports
276+
if let Some(inner_segment) = inner_type.path.segments.last() {
277+
let inner_type_name = inner_segment.ident.to_string();
278+
279+
// Direct match
280+
if inner_segment.ident == *target_type {
281+
return true;
282+
}
283+
284+
// Check if it matches through imports
285+
if let Some(resolved_path) =
286+
import_info.resolve_type(&inner_type_name)
287+
{
288+
if resolved_path.ends_with(&target_type.to_string()) {
289+
return true;
290+
}
291+
}
292+
293+
// Check if target_type matches through imports
294+
let target_type_name = target_type.to_string();
295+
if let Some(resolved_target) =
296+
import_info.resolve_type(&target_type_name)
297+
{
298+
if resolved_target.ends_with(&inner_type_name) {
299+
return true;
300+
}
301+
}
302+
}
303+
}
304+
}
305+
}
306+
}
307+
}
308+
false
309+
}
310+
_ => false,
311+
}
312+
}
313+
141314
/// Parse account attribute to extract init, seeds, and bump information using proper AST parsing
142315
fn parse_account_attribute(attr: &Attribute) -> Result<Option<(bool, Vec<Expr>, bool)>> {
143316
if !attr.path().is_ident("account") {
@@ -262,6 +435,48 @@ fn extract_seeds_from_field(field: &Field, target_type: &Ident) -> Result<Option
262435
Ok(None)
263436
}
264437

438+
/// Enhanced version of extract_seeds_from_field that handles imports
439+
fn extract_seeds_from_field_enhanced(
440+
field: &Field,
441+
target_type: &Ident,
442+
import_info: &ImportInfo,
443+
) -> Result<Option<SeedInfo>> {
444+
// First try the original approach
445+
if let Some(seeds_info) = extract_seeds_from_field(field, target_type)? {
446+
return Ok(Some(seeds_info));
447+
}
448+
449+
// Then try with import resolution
450+
let field_type_matches = matches_account_type_enhanced(&field.ty, target_type, import_info);
451+
452+
if !field_type_matches {
453+
return Ok(None);
454+
}
455+
456+
// Look for account attribute with init and seeds
457+
for attr in &field.attrs {
458+
if let Some((has_init, seeds, has_bump)) = parse_account_attribute(attr)? {
459+
if has_init && !seeds.is_empty() {
460+
let bump_field = if has_bump {
461+
Some(format_ident!("bump"))
462+
} else {
463+
None
464+
};
465+
466+
// Convert instruction parameter references to account field references
467+
let converted_seeds = convert_seed_parameters(seeds, target_type)?;
468+
469+
return Ok(Some(SeedInfo {
470+
seeds: converted_seeds,
471+
bump_field,
472+
}));
473+
}
474+
}
475+
}
476+
477+
Ok(None)
478+
}
479+
265480
/// Generate compress instructions for the specified account types (Anchor version)
266481
pub(crate) fn add_compressible_instructions(
267482
args: TokenStream,
@@ -277,6 +492,9 @@ pub(crate) fn add_compressible_instructions(
277492
// Get the module content
278493
let content = module.content.as_mut().unwrap();
279494

495+
// Parse import information to handle multi-file structures
496+
let import_info = parse_use_statements(&content.1);
497+
280498
// Collect all struct names for the enum
281499
let struct_names: Vec<_> = ident_list.idents.iter().cloned().collect();
282500

@@ -604,13 +822,14 @@ pub(crate) fn add_compressible_instructions(
604822
let compress_accounts_name = format_ident!("Compress{}", struct_name);
605823

606824
// Find seeds for this account type from existing account structs
607-
let seeds_info = find_account_seeds_for_type(&content.1, &struct_name)?
825+
let seeds_info = find_account_seeds_for_type_enhanced(&content.1, &struct_name, &import_info)?
608826
.ok_or_else(|| syn::Error::new_spanned(
609827
&struct_name,
610828
format!(
611829
"No account struct found with 'init' constraint and seeds for type '{}'. \
612830
Please ensure you have an account struct (with #[derive(Accounts)]) that initializes \
613-
this account type with seeds specified in the #[account(init, seeds = [...], ...)] attribute.",
831+
this account type with seeds specified in the #[account(init, seeds = [...], ...)] attribute. \
832+
This can be in the same module or imported via 'use' statements.",
614833
struct_name
615834
)
616835
))?;

sdk-tests/anchor-compressible-derived/Cargo.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,17 @@ light-sdk = { workspace = true, features = ["anchor", "idl-build", "anchor-discr
2424
light-sdk-types = { workspace = true }
2525
light-sdk-macros = { workspace = true }
2626
light-hasher = { workspace = true, features = ["solana"] }
27-
solana-program = { workspace = true }
2827
light-macros = { workspace = true, features = ["solana"] }
28+
solana-program = { workspace = true }
2929
borsh = { workspace = true }
3030
light-compressed-account = { workspace = true, features = ["solana"] }
3131
anchor-lang = { workspace = true, features = ["idl-build"] }
3232

3333
[dev-dependencies]
34-
light-program-test = { workspace = true, features = ["devenv"] }
35-
light-client = { workspace = true, features = ["devenv"] }
34+
light-program-test = { workspace = true, features = ["v2"] }
35+
light-client = { workspace = true, features = ["v2"] }
3636
light-compressible-client = { workspace = true, features = ["anchor"] }
37-
light-test-utils = { workspace = true, features = ["devenv"] }
37+
light-test-utils = { workspace = true }
3838
tokio = { workspace = true }
3939
solana-sdk = { workspace = true }
4040

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
use anchor_lang::prelude::*;
2+
3+
use crate::state::UserRecord;
4+
5+
#[derive(Accounts)]
6+
pub struct CreateRecord<'info> {
7+
#[account(mut)]
8+
pub user: Signer<'info>,
9+
#[account(
10+
init,
11+
payer = user,
12+
// Manually add 10 bytes! Discriminator + owner + string len + name +
13+
// score + option<compression_info>
14+
space = 8 + 32 + 4 + 32 + 8 + 10,
15+
seeds = [b"user_record", user.key().as_ref()],
16+
bump,
17+
)]
18+
pub user_record: Account<'info, UserRecord>,
19+
/// UNCHECKED: checked via config.
20+
#[account(mut)]
21+
pub rent_recipient: AccountInfo<'info>,
22+
/// The global config account
23+
/// UNCHECKED: checked via load_checked.
24+
pub config: AccountInfo<'info>,
25+
pub system_program: Program<'info, System>,
26+
}

0 commit comments

Comments
 (0)