From 663a0824ad4e0cbadd6efbe3c3eb0993f46da01d Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Tue, 14 Nov 2023 02:17:47 +0900 Subject: [PATCH] feat: var-kwargs codegen --- crates/erg_compiler/codegen.rs | 20 +++++++++++++++----- crates/erg_compiler/context/register.rs | 10 +++++----- crates/erg_compiler/ty/codeobj.rs | 3 ++- crates/erg_parser/parse.rs | 10 ++++++++-- tests/should_err/var_kwargs.er | 4 ++++ tests/should_ok/var_kwargs.er | 3 +++ tests/test.rs | 10 ++++++++++ 7 files changed, 47 insertions(+), 13 deletions(-) create mode 100644 tests/should_err/var_kwargs.er create mode 100644 tests/should_ok/var_kwargs.er diff --git a/crates/erg_compiler/codegen.rs b/crates/erg_compiler/codegen.rs index 2e899357c..5a3390495 100644 --- a/crates/erg_compiler/codegen.rs +++ b/crates/erg_compiler/codegen.rs @@ -1026,6 +1026,14 @@ impl PyCodeGenerator { .iter() .map(|p| (p.inspect().map(|s| &s[..]).unwrap_or("_"), &p.sig.vi)), ) + .chain(if let Some(kw_var_args) = ¶ms.kw_var_params { + vec![( + kw_var_args.inspect().map(|s| &s[..]).unwrap_or("_"), + &kw_var_args.vi, + )] + } else { + vec![] + }) .enumerate() .map(|(i, (s, vi))| { if s == "_" { @@ -1405,11 +1413,13 @@ impl PyCodeGenerator { self.stack_dec_n(defaults_len - 1); make_function_flag += MakeFunctionFlags::Defaults as usize; } - let flags = if sig.params.var_params.is_some() { - CodeObjFlags::VarArgs as u32 - } else { - 0 - }; + let mut flags = 0; + if sig.params.var_params.is_some() { + flags += CodeObjFlags::VarArgs as u32; + } + if sig.params.kw_var_params.is_some() { + flags += CodeObjFlags::VarKeywords as u32; + } let code = self.emit_block( body.block, sig.params.guards, diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index 672dd7698..2b9518a18 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -390,16 +390,16 @@ impl Context { opt_decl_t, &mut dummy_tv_cache, Normal, - kind, + kind.clone(), false, ) { Ok(ty) => (ty, TyCheckErrors::empty()), Err((ty, errs)) => (ty, errs), }; - let spec_t = if is_var_params { - unknown_len_array_t(spec_t) - } else { - spec_t + let spec_t = match kind { + ParamKind::VarParams => unknown_len_array_t(spec_t), + ParamKind::KwVarParams => str_dict_t(spec_t), + _ => spec_t, }; if &name.inspect()[..] == "self" { self.type_self_param(&sig.raw.pat, name, &spec_t, &mut errs); diff --git a/crates/erg_compiler/ty/codeobj.rs b/crates/erg_compiler/ty/codeobj.rs index e297ecd03..b9b2241cc 100644 --- a/crates/erg_compiler/ty/codeobj.rs +++ b/crates/erg_compiler/ty/codeobj.rs @@ -280,8 +280,9 @@ impl CodeObj { ) -> Self { let name = name.into(); let var_args_defined = (flags & CodeObjFlags::VarArgs as u32 != 0) as u32; + let kw_var_args_defined = (flags & CodeObjFlags::VarKeywords as u32 != 0) as u32; Self { - argcount: params.len() as u32 - var_args_defined, + argcount: params.len() as u32 - var_args_defined - kw_var_args_defined, posonlyargcount: 0, kwonlyargcount: 0, nlocals: params.len() as u32, diff --git a/crates/erg_parser/parse.rs b/crates/erg_parser/parse.rs index 1221757bc..445996cdf 100644 --- a/crates/erg_parser/parse.rs +++ b/crates/erg_parser/parse.rs @@ -2415,7 +2415,8 @@ impl Parser { debug_exit_info!(self); Ok(call_or_acc) } - Some(t) if t.is(PreStar) => { + Some(t) if t.is(PreStar) || t.is(PreDblStar) => { + let kind = t.kind; let _ = self.lpop(); let expr = self .try_reduce_expr(false, in_type_args, in_brace, false) @@ -2430,8 +2431,13 @@ impl Parser { } self.stack_dec(fn_name!()) })?; + let arg = match kind { + PreStar => ArgKind::Var(PosArg::new(expr)), + PreDblStar => ArgKind::KwVar(PosArg::new(expr)), + _ => switch_unreachable!(), + }; let tuple = self - .try_reduce_nonempty_tuple(ArgKind::Var(PosArg::new(expr)), false) + .try_reduce_nonempty_tuple(arg, false) .map_err(|_| self.stack_dec(fn_name!()))?; debug_exit_info!(self); Ok(Expr::Tuple(tuple)) diff --git a/tests/should_err/var_kwargs.er b/tests/should_err/var_kwargs.er new file mode 100644 index 000000000..8619a3fd2 --- /dev/null +++ b/tests/should_err/var_kwargs.er @@ -0,0 +1,4 @@ +kw_var(**x: Int) = x["::a"] + x["::b"] + +_ = kw_var(a:="1", b:=2) # ERR +_ = kw_var(a:=1, b:="2") # ERR diff --git a/tests/should_ok/var_kwargs.er b/tests/should_ok/var_kwargs.er new file mode 100644 index 000000000..fe93c3e83 --- /dev/null +++ b/tests/should_ok/var_kwargs.er @@ -0,0 +1,3 @@ +kw_var(**x: Int) = x["::a"] + x["::b"] + +assert kw_var(a:=1, b:=2) == 3 diff --git a/tests/test.rs b/tests/test.rs index d88b2cbf1..01e0d73c2 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -410,6 +410,11 @@ fn exec_var_args() -> Result<(), ()> { expect_success("tests/should_ok/var_args.er", 0) } +#[test] +fn exec_var_kwargs() -> Result<(), ()> { + expect_success("tests/should_ok/var_kwargs.er", 0) +} + #[test] fn exec_with() -> Result<(), ()> { expect_success("examples/with.er", 0) @@ -621,6 +626,11 @@ fn exec_var_args_err() -> Result<(), ()> { expect_failure("tests/should_err/var_args.er", 0, 3) } +#[test] +fn exec_var_kwargs_err() -> Result<(), ()> { + expect_failure("tests/should_err/var_kwargs.er", 0, 2) +} + #[test] fn exec_visibility() -> Result<(), ()> { expect_failure("tests/should_err/visibility.er", 2, 7)