Skip to content

Commit

Permalink
fix: handle return statements and enhance call to throw logic
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelangeloio committed Jan 25, 2024
1 parent 2ae9f64 commit 3b76363
Show file tree
Hide file tree
Showing 5 changed files with 196 additions and 22 deletions.
102 changes: 89 additions & 13 deletions crates/does-it-throw/src/call_finder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ extern crate swc_ecma_ast;
extern crate swc_ecma_parser;
extern crate swc_ecma_visit;

use std::collections::HashSet;
use std::collections::hash_map::DefaultHasher;
use std::collections::{HashMap, HashSet};
use std::hash::{Hash, Hasher};

use swc_ecma_ast::{
Expand Down Expand Up @@ -47,6 +48,7 @@ impl Hash for CallToThrowMap {
pub struct InstantiationsMap {
pub class_name: String,
pub variable_name: String,
pub instantiation_span: Span,
}

impl PartialEq for InstantiationsMap {
Expand Down Expand Up @@ -87,12 +89,36 @@ pub struct CallFinder {
pub calls: HashSet<CallToThrowMap>,
pub functions_with_throws: HashSet<ThrowMap>,
pub current_class_name: Option<String>,
pub instantiations: HashSet<InstantiationsMap>,
pub instantiations: HashMap<String, InstantiationsMap>,
pub function_name_stack: Vec<String>,
pub object_property_stack: Vec<String>,
pub processed_calls: HashSet<u64>,
}

impl CallFinder {
// use the lo and hi of the Span to generate a unique ID so that we don't
// record duplicate calls
// We also use the instantiation span for instances to ensure uniqueness
// for calls to methods on instances
fn generate_unique_call_id(&self, call: &CallExpr) -> u64 {
let mut hasher = DefaultHasher::new();

if let Callee::Expr(expr) = &call.callee {
if let Expr::Member(member_expr) = &**expr {
if let Expr::Ident(instance_ident) = &*member_expr.obj {
if let Some(inst_map) = self.instantiations.get(&instance_ident.sym.to_string()) {
// Use the instantiation span for instances
inst_map.instantiation_span.hash(&mut hasher);
}
}
}
}

// Also hash the call span to ensure uniqueness
call.span.hash(&mut hasher);
hasher.finish()
}

fn handle_bin_expr(&mut self, bin_expr: &BinExpr) {
if let Expr::Call(call_expr) = &*bin_expr.left {
self.visit_call_expr(call_expr);
Expand Down Expand Up @@ -204,6 +230,12 @@ impl Visit for CallFinder {

fn visit_call_expr(&mut self, call: &CallExpr) {
if let Callee::Expr(expr) = &call.callee {
let call_id = self.generate_unique_call_id(call);
// If we've already processed this call, skip it
if !self.processed_calls.insert(call_id) {
// This call was already processed, so return early
return;
}
match &**expr {
Expr::Member(member_expr) => {
let mut possible_class_name = None;
Expand All @@ -214,10 +246,8 @@ impl Visit for CallFinder {
}
if let Some(ref obj_name) = possible_class_name {
let mut new_class_name = None;
for instantiation in self.instantiations.iter() {
if &instantiation.variable_name == obj_name {
new_class_name = Some(instantiation.class_name.clone());
}
if let Some(instantiation) = self.instantiations.get(obj_name) {
new_class_name = Some(instantiation.class_name.clone());
}
if let Some(class_name) = new_class_name {
possible_class_name = Some(class_name);
Expand Down Expand Up @@ -322,6 +352,7 @@ impl Visit for CallFinder {
_ => {}
}
}

}

fn visit_var_declarator(&mut self, var_declarator: &VarDeclarator) {
Expand All @@ -332,9 +363,11 @@ impl Visit for CallFinder {
let class_name = expr.sym.to_string();
if let Some(var_ident) = &var_declarator.name.as_ident() {
let var_name = var_ident.sym.to_string();
self.instantiations.insert(InstantiationsMap {
let instantiation_span = var_ident.span;
self.instantiations.insert(var_name.clone(), InstantiationsMap {
class_name,
variable_name: var_name,
instantiation_span,
});
}
}
Expand Down Expand Up @@ -439,16 +472,59 @@ impl Visit for CallFinder {
self.visit_stmt(alt);
}
}
Stmt::Return(return_stmt) => {
if let Some(expr) = &return_stmt.arg {
self.visit_expr(expr);
}
// Handle returning an object expression
if let Some(block_stmt) = &return_stmt.arg.as_ref().and_then(|arg| match arg.as_ref() {
Expr::Object(object_expr) => Some(Box::new(object_expr.clone())),
_ => None,
}) {
for prop_or_spread in &block_stmt.props {
match prop_or_spread {
swc_ecma_ast::PropOrSpread::Prop(boxed_prop) => {
match &**boxed_prop {
swc_ecma_ast::Prop::KeyValue(key_value_prop) => {
// Handle KeyValue
self.visit_expr(&key_value_prop.value);
}
swc_ecma_ast::Prop::Assign(assign_prop) => {
// Handle Assign
self.visit_expr(&assign_prop.value);
}
swc_ecma_ast::Prop::Getter(getter_prop) => {
// Handle Getter
if let Some(body) = &getter_prop.body {
for stmt in &body.stmts {
self.visit_stmt(stmt);
}
}
}
swc_ecma_ast::Prop::Setter(setter_prop) => {
// Handle Setter
if let Some(body) = &setter_prop.body {
for stmt in &body.stmts {
self.visit_stmt(stmt);
}
}
}
_ => {}
}
}
swc_ecma_ast::PropOrSpread::Spread(spread) => {
// Handle Spread syntax
self.visit_expr(&spread.expr);
}
_ => {}

Check warning on line 519 in crates/does-it-throw/src/call_finder.rs

View workflow job for this annotation

GitHub Actions / build-and-test

unreachable pattern
}
}
}
}
_ => {
// For other kinds of statements, we continue with the default implementation (for now)
swc_ecma_visit::visit_stmt(self, stmt);
}
}
}

fn visit_expr(&mut self, expr: &Expr) {
if let Expr::Call(call_expr) = expr {
self.visit_call_expr(call_expr)
}
}
}
8 changes: 8 additions & 0 deletions crates/does-it-throw/src/fixtures/callExpr.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,11 @@ SomeRandomCall2(() => {
connection.oneWithASecondArg({}, () => {
throw new Error('hi khue')
})

const testGetter = {
get test() {
SomeThrow()
}
}

const array = [SomeThrow(), SomeThrow2()]
32 changes: 32 additions & 0 deletions crates/does-it-throw/src/fixtures/returnStatement.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// @ts-nocheck
const someThrow = () => {
if (something) {
while (true) {
throw new Error("oh no");
}
} else {
for (let i = 0; i < 10; i++) {
throw new Error("oh no");
}
}
}
class Test {
badMethod() {
throw new Error("oh no");
}
}

const callToSomeThrow = () => {
const testMethod = new Test();
return {
test: someThrow(),
testing: () => someThrow(),
array: [someThrow(), someThrow()],
object: { test: someThrow() },
class: testMethod.badMethod(),
}
}

function test() {
return someThrow();
}
5 changes: 3 additions & 2 deletions crates/does-it-throw/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ extern crate swc_ecma_ast;
extern crate swc_ecma_parser;
extern crate swc_ecma_visit;

use std::collections::HashSet;
use std::collections::{HashMap, HashSet};

use std::vec;

Expand Down Expand Up @@ -94,10 +94,11 @@ pub fn analyze_code(
};
throw_collector.visit_module(&module);
let mut call_collector = CallFinder {
processed_calls: HashSet::new(),
functions_with_throws: throw_collector.functions_with_throws.clone(),
calls: HashSet::new(),
current_class_name: None,
instantiations: HashSet::new(),
instantiations: HashMap::new(),
function_name_stack: vec![],
object_property_stack: vec![],
};
Expand Down
71 changes: 64 additions & 7 deletions crates/does-it-throw/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ mod integration_tests {

// general result assertions
assert_eq!(result.functions_with_throws.len(), 5);
assert_eq!(result.calls_to_throws.len(), 7);
assert_eq!(result.calls_to_throws.len(), 10);
assert_eq!(result.imported_identifier_usages.len(), 0);
assert_eq!(result.import_sources.len(), 0);

Expand Down Expand Up @@ -467,18 +467,20 @@ mod integration_tests {
.iter()
.map(|c| c.id.clone())
.collect();

fn calls_to_throws_contains(calls_to_throws: &Vec<String>, call_to_throw: &str) -> bool {
calls_to_throws.iter().any(|c| c == call_to_throw)
}
[
"NOT_SET-onInitialized",
"NOT_SET-SomeRandomCall2",
"NOT_SET-<anonymous>",
"connection-<anonymous>",
"NOT_SET-<anonymous>",
"NOT_SET-onInitialized",
"NOT_SET-<anonymous>",
"NOT_SET-SomeRandomCall2",
"NOT_SET-SomeRandomCall2",
"NOT_SET-onInitialized",
"NOT_SET-<anonymous>",
"connection-<anonymous>",
"NOT_SET-<anonymous>",
]
.iter()
.for_each(|f| assert!(calls_to_throws_contains(&calls_to_throws, f)));
Expand Down Expand Up @@ -939,7 +941,7 @@ mod integration_tests {
}

#[test]
fn test_should_not_include_throws_for_ignore_statements () {
fn test_should_not_include_throws_for_ignore_statements() {
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let file_path = format!("{}/src/fixtures/ignoreStatements.ts", manifest_dir);
// Read sample code from file
Expand All @@ -961,7 +963,7 @@ mod integration_tests {
}

#[test]
fn test_should_not_include_throws_for_ignore_statements_js () {
fn test_should_not_include_throws_for_ignore_statements_js() {
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let file_path = format!("{}/src/fixtures/ignoreStatements.js", manifest_dir);
// Read sample code from file
Expand All @@ -981,4 +983,59 @@ mod integration_tests {
assert_eq!(result.functions_with_throws.len(), 0);
assert_eq!(result.calls_to_throws.len(), 0);
}

#[test]
fn test_should_handle_return_statements() {
let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
let file_path = format!("{}/src/fixtures/returnStatement.ts", manifest_dir);
// Read sample code from file
let sample_code = fs::read_to_string(file_path).expect("Something went wrong reading the file");
let ignore_statements = vec![
"@it-throws".to_string(),
"@it-throws-ignore".to_string(),
"@some-random-ignore".to_string(),
];
let cm: Lrc<SourceMap> = Default::default();
let user_settings = UserSettings {
include_try_statement_throws: true,
ignore_statements,
};
let (result, _cm) = analyze_code(&sample_code, cm, &user_settings);

assert_eq!(result.functions_with_throws.len(), 2);
assert_eq!(result.calls_to_throws.len(), 7);

let function_names: Vec<String> = result
.functions_with_throws
.clone()
.into_iter()
.map(|f| f.function_or_method_name)
.collect();

fn function_names_contains(function_names: &Vec<String>, function_name: &str) -> bool {
function_names.iter().any(|f| f == function_name)
}

["someThrow", "badMethod"]
.iter()
.for_each(|f| assert!(function_names_contains(&function_names, f)));

let calls_to_throws: Vec<String> = result.calls_to_throws.into_iter().map(|c| c.id).collect();

fn calls_to_throws_contains(calls_to_throws: &Vec<String>, call_to_throw: &str) -> bool {
calls_to_throws.iter().any(|c| c == call_to_throw)
}

[
"NOT_SET-callToSomeThrow",
"NOT_SET-callToSomeThrow",
"NOT_SET-callToSomeThrow",
"NOT_SET-test",
"NOT_SET-callToSomeThrow",
"NOT_SET-callToSomeThrow",
"Test-callToSomeThrow",
]
.iter()
.for_each(|f| assert!(calls_to_throws_contains(&calls_to_throws, f)));
}
}

0 comments on commit 3b76363

Please sign in to comment.